diff --git a/CHANGELOG.md b/CHANGELOG.md index 69f628d8a0..432ffffd08 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -44,6 +44,8 @@ This project adheres to [Semantic Versioning](https://semver.org/). ### Added - Added `AT_EACCESS` to `AtFlags` on all platforms but android ([#1995](https://github.com/nix-rust/nix/pull/1995)) +- Added mutex interface. + ([#1950](https://github.com/nix-rust/nix/pull/1950)) - Add `PF_ROUTE` to `SockType` on macOS, iOS, all of the BSDs, Fuchsia, Haiku, Illumos. ([#1867](https://github.com/nix-rust/nix/pull/1867)) - Added `nix::ucontext` module on `aarch64-unknown-linux-gnu`. diff --git a/src/sys/pthread.rs b/src/sys/pthread.rs index 6bad03a4d4..f6c4b4db13 100644 --- a/src/sys/pthread.rs +++ b/src/sys/pthread.rs @@ -5,6 +5,13 @@ use crate::errno::Errno; #[cfg(not(target_os = "redox"))] use crate::Result; use libc::{self, pthread_t}; +#[cfg(not(target_os = "redox"))] +use libc::c_int; + +#[cfg(target_os = "linux")] +use std::cell::UnsafeCell; +#[cfg(target_os = "linux")] +use std::default::Default; /// Identifies an individual thread. pub type Pthread = pthread_t; @@ -34,10 +41,285 @@ pub fn pthread_kill(thread: Pthread, signal: T) -> Result<()> where T: Into> { let sig = match signal.into() { - Some(s) => s as libc::c_int, + Some(s) => s as c_int, None => 0, }; let res = unsafe { libc::pthread_kill(thread, sig) }; Errno::result(res).map(drop) } } + +/// Mutex protocol. +#[cfg(target_os = "linux")] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(i32)] +pub enum Protocol { + /// [`libc::PTHREAD_PRIO_NONE`] + None = libc::PTHREAD_PRIO_NONE, + /// [`libc::PTHREAD_PRIO_INHERIT`] + Inherit = libc::PTHREAD_PRIO_INHERIT, + /// [`libc::PTHREAD_PRIO_PROTECT`] + Protect = libc::PTHREAD_PRIO_PROTECT +} +#[cfg(target_os = "linux")] +impl From for Protocol { + fn from(x: i32) -> Self { + match x { + libc::PTHREAD_PRIO_NONE => Self::None, + libc::PTHREAD_PRIO_INHERIT => Self::Inherit, + libc::PTHREAD_PRIO_PROTECT => Self::Protect, + _ => unreachable!() + } + } +} + +/// Mutex attributes. +#[cfg(target_os = "linux")] +#[derive(Debug)] +pub struct MutexAttr(libc::pthread_mutexattr_t); + +#[cfg(target_os = "linux")] +impl MutexAttr { + /// Wraps [`libc::pthread_mutexattr_init`]. + pub fn new() -> Result { + let attr = unsafe { + let mut uninit = std::mem::MaybeUninit::::uninit(); + Errno::result(libc::pthread_mutexattr_init(uninit.as_mut_ptr()))?; + uninit.assume_init() + }; + Ok(Self(attr)) + } + /// Wraps [`libc::pthread_mutexattr_getpshared`]. + pub fn get_shared(&self) -> Result { + let init = unsafe { + let mut uninit = std::mem::MaybeUninit::uninit(); + Errno::result(libc::pthread_mutexattr_getpshared(&self.0,uninit.as_mut_ptr()))?; + uninit.assume_init() + }; + Ok(init == libc::PTHREAD_PROCESS_SHARED) + } + /// Wraps [`libc::pthread_mutexattr_setpshared`]. + pub fn set_shared(&mut self, shared: bool) -> Result<()> { + let shared = if shared { libc::PTHREAD_PROCESS_SHARED} else { libc::PTHREAD_PROCESS_PRIVATE }; + unsafe { + Errno::result(libc::pthread_mutexattr_setpshared(&mut self.0,shared)).map(drop) + } + } + /// Wraps [`libc::pthread_mutexattr_getrobust`]. + pub fn get_robust(&self) -> Result { + let init = unsafe { + let mut uninit = std::mem::MaybeUninit::uninit(); + Errno::result(libc::pthread_mutexattr_getrobust(&self.0,uninit.as_mut_ptr()))?; + uninit.assume_init() + }; + Ok(init == libc::PTHREAD_MUTEX_ROBUST) + } + /// Wraps [`libc::pthread_mutexattr_setrobust`]. + pub fn set_robust(&mut self, robust: bool) -> Result<()> { + let robust = if robust { libc::PTHREAD_MUTEX_ROBUST} else { libc::PTHREAD_MUTEX_STALLED }; + unsafe { + Errno::result(libc::pthread_mutexattr_setrobust(&mut self.0,robust)).map(drop) + } + } + /// Wraps [`libc::pthread_mutexattr_getprotocol`]. + pub fn get_protocol(&self) -> Result { + let init = unsafe { + let mut uninit = std::mem::MaybeUninit::uninit(); + Errno::result(libc::pthread_mutexattr_getprotocol(&self.0,uninit.as_mut_ptr()))?; + uninit.assume_init() + }; + Ok(Protocol::from(init)) + } + /// Wraps [`libc::pthread_mutexattr_setprotocol`]. + pub fn set_protocol(&mut self, protocol: Protocol) -> Result<()> { + unsafe { + Errno::result(libc::pthread_mutexattr_setprotocol(&mut self.0,protocol as i32)).map(drop) + } + } +} + +#[cfg(target_os = "linux")] +impl Default for MutexAttr { + fn default() -> Self { + let mutex_attr = Self::new().unwrap(); + debug_assert_eq!(mutex_attr.get_shared(),Ok(true)); + mutex_attr + } +} +#[cfg(target_os = "linux")] +impl std::ops::Drop for MutexAttr { + /// Wraps [`libc::pthread_mutexattr_destroy`]. + fn drop(&mut self) { + unsafe { + Errno::result(libc::pthread_mutexattr_destroy(&mut self.0)).unwrap(); + } + } +} + +/// Mutex. +/// ``` +/// # use std::{ +/// # sync::Arc, +/// # time::{Instant, Duration}, +/// # thread::{sleep, spawn}, +/// # mem::size_of, +/// # num::NonZeroUsize, +/// # os::unix::io::OwnedFd +/// # }; +/// # use nix::{ +/// # sys::{pthread::{Mutex, MutexAttr}, mman::{mmap, MapFlags, ProtFlags}}, +/// # unistd::{fork,ForkResult}, +/// # }; +/// const TIMEOUT: Duration = Duration::from_millis(500); +/// const DELTA: Duration = Duration::from_millis(100); +/// # fn main() -> nix::Result<()> { +/// let mutex = Mutex::default(); +/// +/// // The mutex is initialized unlocked, so an attempt to unlock it should +/// // return immediately. +/// assert_eq!(mutex.unlock(), Ok(())); +/// // The mutex is unlocked, so `try_lock` will lock. +/// assert_eq!(mutex.try_lock(), Ok(true)); +/// // Unlock the mutex. +/// assert_eq!(mutex.unlock(), Ok(())); +/// // The mutex is unlocked, so `lock` will lock and exit immediately. +/// assert_eq!(mutex.lock(), Ok(())); +/// // Unlock the mutex. +/// assert_eq!(mutex.unlock(), Ok(())); +/// +/// // Test across threads +/// // ------------------------------------------------------------------------- +/// +/// let mutex = Arc::new(mutex); +/// let mutex_clone = mutex.clone(); +/// let instant = Instant::now(); +/// spawn(move || { +/// assert_eq!(mutex_clone.lock(), Ok(())); +/// sleep(TIMEOUT); +/// assert_eq!(mutex_clone.unlock(), Ok(())); +/// }); +/// sleep(DELTA); +/// assert_eq!(mutex.lock(), Ok(())); +/// assert!(instant.elapsed() > TIMEOUT && instant.elapsed() < TIMEOUT + DELTA); +/// +/// // Test across processes +/// // ------------------------------------------------------------------------- +/// +/// let shared_memory = unsafe { mmap::( +/// None, +/// NonZeroUsize::new_unchecked(size_of::()), +/// ProtFlags::PROT_WRITE | ProtFlags::PROT_READ, +/// MapFlags::MAP_SHARED | MapFlags::MAP_ANONYMOUS, +/// None, +/// 0 +/// )? }; +/// let mutex_ptr = shared_memory.cast::(); +/// let mutex = unsafe { &*mutex_ptr }; +/// +/// // If transmute or cast into a mutex, you must ensure it is initialized. +/// // By default mutex's are process private, so we need to initialize with the `MutexAttr` with +/// // shared. +/// let mut mutex_attr = MutexAttr::new()?; +/// mutex_attr.set_shared(true)?; +/// mutex.init(Some(mutex_attr))?; +/// +/// match unsafe { fork()? } { +/// ForkResult::Parent { child } => { +/// assert_eq!(mutex.lock(), Ok(())); +/// sleep(TIMEOUT); +/// assert_eq!(mutex.unlock(), Ok(())); +/// // Wait for child process to exit +/// unsafe { +/// assert_eq!(libc::waitpid(child.as_raw(),std::ptr::null_mut(),0),child.as_raw()); +/// } +/// }, +/// ForkResult::Child => { +/// let now = Instant::now(); +/// sleep(DELTA); +/// assert_eq!(mutex.lock(), Ok(())); +/// assert!(now.elapsed() > TIMEOUT && now.elapsed() < TIMEOUT + DELTA); +/// } +/// } +/// +/// # Ok(()) +/// # } +/// ``` +#[cfg(target_os = "linux")] +#[derive(Debug)] +pub struct Mutex(UnsafeCell); +#[cfg(target_os = "linux")] +impl Mutex { + /// Wraps [`libc::pthread_mutex_init`]. + pub fn init(&self, attr: Option) -> Result<()> { + let attr = match attr { + Some(mut x) => &mut x.0, + None => std::ptr::null_mut() + }; + unsafe { + Errno::result(libc::pthread_mutex_init(self.0.get(),attr))?; + } + Ok(()) + } + /// Wraps [`libc::pthread_mutex_init`]. + pub fn new(attr: Option) -> Result { + let attr = match attr { + Some(mut x) => &mut x.0, + None => std::ptr::null_mut() + }; + let init = unsafe { + let mut uninit = std::mem::MaybeUninit::::uninit(); + Errno::result(libc::pthread_mutex_init(uninit.as_mut_ptr(),attr))?; + uninit.assume_init() + }; + Ok(Self(UnsafeCell::new(init))) + } + /// Wraps [`libc::pthread_mutex_lock`]. + /// + /// + pub fn lock(&self) -> Result<()> { + unsafe { + Errno::result(libc::pthread_mutex_lock(self.0.get())).map(drop) + } + } + /// Wraps [`libc::pthread_mutex_trylock`]. + /// + /// + pub fn try_lock(&self) -> Result { + unsafe { + match Errno::result(libc::pthread_mutex_trylock(self.0.get())) { + Ok(_) => Ok(true), + Err(Errno::EBUSY) => Ok(false), + Err(err) => Err(err) + } + + } + } + /// Wraps [`libc::pthread_mutex_unlock`]. + /// + /// + pub fn unlock(&self) -> Result<()> { + unsafe { + Errno::result(libc::pthread_mutex_unlock(self.0.get())).map(drop) + } + } +} + +#[cfg(target_os = "linux")] +unsafe impl Sync for Mutex {} + +#[cfg(target_os = "linux")] +impl Default for Mutex { + fn default() -> Self { + Self::new(Default::default()).unwrap() + } +} + +#[cfg(target_os = "linux")] +impl std::ops::Drop for Mutex { + /// Wraps [`libc::pthread_mutex_destroy`]. + fn drop(&mut self) { + unsafe { + Errno::result(libc::pthread_mutex_destroy(self.0.get())).unwrap(); + } + } +} \ No newline at end of file diff --git a/test/sys/test_pthread.rs b/test/sys/test_pthread.rs index ce048bae60..0215b6ee5e 100644 --- a/test/sys/test_pthread.rs +++ b/test/sys/test_pthread.rs @@ -20,3 +20,103 @@ fn test_pthread_kill_none() { pthread_kill(pthread_self(), None) .expect("Should be able to send signal to my thread."); } + +#[test] +#[cfg(target_os = "linux")] +fn test_pthread_mutex_wrapper() { + use nix::{ + sys::{ + mman::{mmap, MapFlags, ProtFlags}, + pthread::{Mutex, MutexAttr}, + }, + unistd::{fork, ForkResult}, + }; + use std::{mem::size_of, num::NonZeroUsize, os::unix::io::OwnedFd}; + struct MutexWrapper { + lock: Mutex, + data: u128, + } + impl MutexWrapper { + fn add(&mut self) { + self.lock.lock().unwrap(); + self.data += 1; + self.lock.unlock().unwrap(); + } + } + + /// Number of forks to spawn that mutate the data, 2^n processes. + const FORKS: usize = 3; + /// Number of threads each process spawns that mutate the data. + const THREADS: usize = 20; + /// Number of iterations each thread mutates the data. + const ITERATIONS: usize = 100_000; + + let mut mutex_attr = MutexAttr::new().unwrap(); + mutex_attr.set_shared(true).unwrap(); + + let mutex_wrapper = unsafe { + mmap::( + None, + NonZeroUsize::new_unchecked(size_of::()), + ProtFlags::PROT_WRITE | ProtFlags::PROT_READ, + MapFlags::MAP_SHARED | MapFlags::MAP_ANONYMOUS, + None, + 0, + ) + .unwrap() + .cast() + }; + + unsafe { + std::ptr::write( + mutex_wrapper, + MutexWrapper { + lock: Mutex::new(Some(mutex_attr)).unwrap(), + data: 0, + }, + ); + } + + let fork_results = (0..FORKS) + .map(|_| unsafe { fork().unwrap() }) + .collect::>(); + + let handles = (0..THREADS) + .map(|_| { + let bits = mutex_wrapper as usize; + std::thread::spawn(move || { + let ptr = bits as *mut MutexWrapper; + + let wrapper_ref = unsafe { &mut *ptr }; + for _ in 0..ITERATIONS { + wrapper_ref.add(); + } + }) + }) + .collect::>(); + + for handle in handles { + handle.join().unwrap(); + } + + // The root process will be the parent in all its fork results. + let mut root = true; + for fork_result in fork_results { + if let ForkResult::Parent { child } = fork_result { + unsafe { + assert_eq!( + libc::waitpid(child.as_raw(), std::ptr::null_mut(), 0), + child.as_raw() + ); + } + } else { + root = false; + } + } + if root { + let mutex = unsafe { &*mutex_wrapper }; + let steps = + 2u128.pow(FORKS as u32) * (THREADS as u128) * (ITERATIONS as u128); + assert_eq!(mutex.data, steps); + } +}