diff --git a/CHANGELOG.md b/CHANGELOG.md index 5a4c441805..19ce2294cd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,10 @@ This project adheres to [Semantic Versioning](https://semver.org/). - Fixed the function signature of `recvmmsg`, potentially causing UB ([#2119](https://github.com/nix-rust/nix/issues/2119)) +### Added + +- Added mutex interface. + ([#1950](https://github.com/nix-rust/nix/pull/1950)) ### Changed diff --git a/src/sys/pthread.rs b/src/sys/pthread.rs index 6bad03a4d4..5b076b6d3f 100644 --- a/src/sys/pthread.rs +++ b/src/sys/pthread.rs @@ -5,6 +5,11 @@ use crate::errno::Errno; #[cfg(not(target_os = "redox"))] use crate::Result; use libc::{self, pthread_t}; +#[cfg(not(target_os = "redox"))] +use libc::c_int; + +#[cfg(target_os = "linux")] +use std::cell::UnsafeCell; /// Identifies an individual thread. pub type Pthread = pthread_t; @@ -34,10 +39,351 @@ pub fn pthread_kill(thread: Pthread, signal: T) -> Result<()> where T: Into> { let sig = match signal.into() { - Some(s) => s as libc::c_int, + Some(s) => s as c_int, None => 0, }; let res = unsafe { libc::pthread_kill(thread, sig) }; Errno::result(res).map(drop) } } + +/// Mutex protocol. +#[cfg(target_os = "linux")] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(i32)] +pub enum Protocol { + /// [`libc::PTHREAD_PRIO_NONE`] + None = libc::PTHREAD_PRIO_NONE, + /// [`libc::PTHREAD_PRIO_INHERIT`] + Inherit = libc::PTHREAD_PRIO_INHERIT, + /// [`libc::PTHREAD_PRIO_PROTECT`] + Protect = libc::PTHREAD_PRIO_PROTECT +} +#[cfg(target_os = "linux")] +impl From for Protocol { + fn from(x: i32) -> Self { + match x { + libc::PTHREAD_PRIO_NONE => Self::None, + libc::PTHREAD_PRIO_INHERIT => Self::Inherit, + libc::PTHREAD_PRIO_PROTECT => Self::Protect, + _ => unreachable!() + } + } +} + +/// Mutex attributes. +#[cfg(target_os = "linux")] +#[derive(Debug)] +#[repr(transparent)] +pub struct MutexAttr(libc::pthread_mutexattr_t); + +#[cfg(target_os = "linux")] +impl MutexAttr { + /// Wraps [`libc::pthread_mutexattr_init`]. + pub fn new() -> Result { + let attr = unsafe { + let mut uninit = std::mem::MaybeUninit::::uninit(); + Errno::result(libc::pthread_mutexattr_init(uninit.as_mut_ptr()))?; + uninit.assume_init() + }; + Ok(Self(attr)) + } + + /// Wraps [`libc::pthread_mutexattr_getpshared`]. + pub fn get_shared(&self) -> Result { + let init = unsafe { + let mut uninit = std::mem::MaybeUninit::uninit(); + Errno::result(libc::pthread_mutexattr_getpshared(&self.0,uninit.as_mut_ptr()))?; + uninit.assume_init() + }; + Ok(init == libc::PTHREAD_PROCESS_SHARED) + } + /// Wraps [`libc::pthread_mutexattr_setpshared`]. + pub fn set_shared(&mut self, shared: bool) -> Result<()> { + let shared = if shared { libc::PTHREAD_PROCESS_SHARED} else { libc::PTHREAD_PROCESS_PRIVATE }; + unsafe { + Errno::result(libc::pthread_mutexattr_setpshared(&mut self.0,shared)).map(drop) + } + } + /// Wraps [`libc::pthread_mutexattr_getrobust`]. + pub fn get_robust(&self) -> Result { + let init = unsafe { + let mut uninit = std::mem::MaybeUninit::uninit(); + Errno::result(libc::pthread_mutexattr_getrobust(&self.0,uninit.as_mut_ptr()))?; + uninit.assume_init() + }; + Ok(init == libc::PTHREAD_MUTEX_ROBUST) + } + /// Wraps [`libc::pthread_mutexattr_setrobust`]. + pub fn set_robust(&mut self, robust: bool) -> Result<()> { + let robust = if robust { libc::PTHREAD_MUTEX_ROBUST} else { libc::PTHREAD_MUTEX_STALLED }; + unsafe { + Errno::result(libc::pthread_mutexattr_setrobust(&mut self.0,robust)).map(drop) + } + } + /// Wraps [`libc::pthread_mutexattr_getprotocol`]. + pub fn get_protocol(&self) -> Result { + let init = unsafe { + let mut uninit = std::mem::MaybeUninit::uninit(); + Errno::result(libc::pthread_mutexattr_getprotocol(&self.0,uninit.as_mut_ptr()))?; + uninit.assume_init() + }; + Ok(Protocol::from(init)) + } + /// Wraps [`libc::pthread_mutexattr_setprotocol`]. + pub fn set_protocol(&mut self, protocol: Protocol) -> Result<()> { + unsafe { + Errno::result(libc::pthread_mutexattr_setprotocol(&mut self.0,protocol as i32)).map(drop) + } + } +} + +#[cfg(target_os = "linux")] +impl std::ops::Drop for MutexAttr { + /// Wraps [`libc::pthread_mutexattr_destroy`]. + fn drop(&mut self) { + unsafe { + Errno::result(libc::pthread_mutexattr_destroy(&mut self.0)).unwrap(); + } + } +} + +/// Pthread Mutex. +/// +/// ### Getting started +/// ``` +/// # use std::{ +/// # sync::Arc, +/// # time::{Instant, Duration}, +/// # thread::{sleep, spawn}, +/// # mem::size_of, +/// # num::NonZeroUsize, +/// # os::unix::io::OwnedFd +/// # }; +/// # use nix::{ +/// # sys::{pthread::{Mutex, MutexAttr}, mman::{mmap, MapFlags, ProtFlags}}, +/// # unistd::{fork,ForkResult}, +/// # }; +/// const TIMEOUT: Duration = Duration::from_millis(500); +/// const DELTA: Duration = Duration::from_millis(100); +/// # fn main() -> nix::Result<()> { +/// let mutex = Mutex::new(None)?; +/// +/// // The mutex is initialized unlocked, so an attempt to unlock it should +/// // return immediately. +/// assert_eq!(unsafe { mutex.unlock() }, Ok(())); +/// // The mutex is unlocked, so `try_lock` will lock. +/// let guard = mutex.try_lock()?.unwrap(); +/// // Unlock the mutex. +/// drop(guard); +/// // The mutex is unlocked, so `lock` will lock and exit immediately. +/// let guard = mutex.lock()?; +/// // Unlock the mutex. +/// guard.try_unlock()?; +/// # Ok(()) +/// # } +/// ``` +/// +/// ### Multi-thread +/// +/// ``` +/// # use std::{ +/// # sync::Arc, +/// # time::{Instant, Duration}, +/// # thread::{sleep, spawn}, +/// # mem::size_of, +/// # num::NonZeroUsize, +/// # os::unix::io::OwnedFd +/// # }; +/// # use nix::{ +/// # sys::{pthread::{Mutex, MutexAttr}, mman::{mmap, MapFlags, ProtFlags}}, +/// # unistd::{fork,ForkResult}, +/// # }; +/// const TIMEOUT: Duration = Duration::from_millis(500); +/// const DELTA: Duration = Duration::from_millis(100); +/// # fn main() -> nix::Result<()> { +/// let mutex = Mutex::new(None)?; +/// let mutex_arc = Arc::new(mutex); +/// let mutex_clone = mutex_arc.clone(); +/// let instant = Instant::now(); +/// let handle = spawn(move || -> nix::Result<(),> { +/// let guard = mutex_clone.lock()?; +/// sleep(TIMEOUT); +/// guard.try_unlock()?; +/// Ok(()) +/// }); +/// sleep(DELTA); +/// let guard = mutex_arc.lock()?; +/// assert!(instant.elapsed() > TIMEOUT && instant.elapsed() < TIMEOUT + DELTA); +/// assert_eq!(handle.join().unwrap(), Ok(())); +/// # Ok(()) +/// # } +/// ``` +/// +/// ### Multi-process +/// +/// ``` +/// # use std::{ +/// # sync::Arc, +/// # time::{Instant, Duration}, +/// # thread::{sleep, spawn}, +/// # mem::{size_of, MaybeUninit}, +/// # num::NonZeroUsize, +/// # os::unix::io::OwnedFd +/// # }; +/// # use nix::{ +/// # sys::{pthread::{Mutex, MutexAttr}, mman::{mmap, MapFlags, ProtFlags}}, +/// # unistd::{fork,ForkResult}, +/// # }; +/// const TIMEOUT: Duration = Duration::from_millis(500); +/// const DELTA: Duration = Duration::from_millis(100); +/// # fn main() -> nix::Result<()> { +/// let shared_memory = unsafe { mmap::( +/// None, +/// NonZeroUsize::new_unchecked(size_of::()), +/// ProtFlags::PROT_WRITE | ProtFlags::PROT_READ, +/// MapFlags::MAP_SHARED | MapFlags::MAP_ANONYMOUS, +/// None, +/// 0 +/// )? }; +/// +/// let mutex = unsafe { +/// let mutex_ptr = shared_memory.cast::(); +/// +/// // A mutex must be initialized. +/// // By default mutex's are process private, so we also need to initialize with the +/// // `MutexAttr` with shared. +/// let mut mutex_attr = MutexAttr::new()?; +/// mutex_attr.set_shared(true)?; +/// Mutex::init(mutex_ptr, Some(mutex_attr))?; +/// +/// &*mutex_ptr +/// }; +/// +/// match unsafe { fork()? } { +/// ForkResult::Parent { child } => { +/// let guard = mutex.lock()?; +/// sleep(TIMEOUT); +/// guard.try_unlock(); +/// // Wait for child process to exit +/// unsafe { +/// assert_eq!(libc::waitpid(child.as_raw(), std::ptr::null_mut(),0), child.as_raw()); +/// } +/// }, +/// ForkResult::Child => { +/// let now = Instant::now(); +/// sleep(DELTA); +/// mutex.lock()?; +/// assert!(now.elapsed() > TIMEOUT && now.elapsed() < TIMEOUT + DELTA); +/// } +/// } +/// +/// # Ok(()) +/// # } +/// ``` +#[cfg(target_os = "linux")] +#[derive(Debug)] +#[repr(transparent)] +pub struct Mutex(UnsafeCell); + +#[cfg(target_os = "linux")] +impl Mutex { + /// Wraps [`libc::pthread_mutex_init`]. + /// + /// # Safety + /// + /// Requires `mutex` contains `Mutex`. + pub unsafe fn init(mutex: *mut Mutex, attr: Option) -> Result<()> { + let attr = match attr { + Some(mut x) => &mut x.0, + None => std::ptr::null_mut() + }; + Errno::result(libc::pthread_mutex_init((*mutex).0.get(),attr))?; + Ok(()) + } + /// Wraps [`libc::pthread_mutex_init`]. + pub fn new(attr: Option) -> Result { + let attr = match attr { + Some(mut x) => &mut x.0, + None => std::ptr::null_mut() + }; + let init = unsafe { + let mut uninit = std::mem::MaybeUninit::::uninit(); + Errno::result(libc::pthread_mutex_init(uninit.as_mut_ptr(),attr))?; + uninit.assume_init() + }; + Ok(Self(UnsafeCell::new(init))) + } + /// Wraps [`libc::pthread_mutex_lock`]. + /// + /// + pub fn lock(&self) -> Result> { + unsafe { + Errno::result(libc::pthread_mutex_lock(self.0.get())).map(|_| MutexGuard(self)) + } + } + /// Wraps [`libc::pthread_mutex_trylock`]. + /// + /// + pub fn try_lock(&self) -> Result>> { + unsafe { + match Errno::result(libc::pthread_mutex_trylock(self.0.get())) { + Ok(_) => Ok(Some(MutexGuard(self))), + Err(Errno::EBUSY) => Ok(None), + Err(err) => Err(err) + } + + } + } + /// Wraps [`libc::pthread_mutex_unlock`]. + /// + /// + /// + /// Prefer unlocking by dropping the [`MutexGuard`] returned by [`Mutex::lock`] or [`Mutex::try_lock`]. + /// + /// # Safety + /// + /// Results in UB if not called from the thread that locked the mutex. + pub unsafe fn unlock(&self) -> Result<()> { + unsafe { + Errno::result(libc::pthread_mutex_unlock(self.0.get())).map(drop) + } + } +} + +#[cfg(target_os = "linux")] +unsafe impl Sync for Mutex {} + +#[cfg(target_os = "linux")] +impl std::ops::Drop for Mutex { + /// Wraps [`libc::pthread_mutex_destroy`]. + fn drop(&mut self) { + unsafe { + Errno::result(libc::pthread_mutex_destroy(self.0.get())).unwrap(); + } + } +} + +/// Mutex guard to prevent unlocking a mutex from a different thread than the thread that locked it. +#[cfg(target_os = "linux")] +#[derive(Debug)] +pub struct MutexGuard<'a>(&'a Mutex); + +#[cfg(target_os = "linux")] +impl MutexGuard<'_> { + /// Calls [`Mutex::unlock`]. + pub fn try_unlock(self) -> Result<()> { + // Prevent calling `Self::Drop` which would attempt to unlock twice. + unsafe { std::mem::ManuallyDrop::new(self).0.unlock() } + } +} + +#[cfg(target_os = "linux")] +impl std::ops::Drop for MutexGuard<'_> { + /// Calls [`Mutex::unlock`]. + fn drop(&mut self) { + unsafe { + self.0.unlock().unwrap(); + } + } +} \ No newline at end of file diff --git a/test/sys/test_pthread.rs b/test/sys/test_pthread.rs index ce048bae60..455bb65829 100644 --- a/test/sys/test_pthread.rs +++ b/test/sys/test_pthread.rs @@ -20,3 +20,101 @@ fn test_pthread_kill_none() { pthread_kill(pthread_self(), None) .expect("Should be able to send signal to my thread."); } + +#[test] +#[cfg(target_os = "linux")] +fn test_pthread_mutex_wrapper() { + use nix::{ + sys::{ + mman::{mmap, MapFlags, ProtFlags}, + pthread::{Mutex, MutexAttr}, + }, + unistd::{fork, ForkResult}, + }; + use std::{ + cell::UnsafeCell, mem::size_of, num::NonZeroUsize, + os::unix::io::OwnedFd, + }; + struct MutexWrapper { + lock: Mutex, + data: UnsafeCell, + } + impl MutexWrapper { + fn add(&self) { + let guard = self.lock.lock().unwrap(); + unsafe { *self.data.get() += 1 }; + // guard.unlock().unwrap(); + guard.try_unlock().unwrap(); + } + } + unsafe impl Sync for MutexWrapper {} + + /// Number of forks to spawn that mutate the data, will spawn `2^FORKS` processes. + const FORKS: usize = 3; + /// Number of threads each process spawns that mutate the data, will spawn `2^FORKS * THREADS` threads. + const THREADS: usize = 20; + /// Number of iterations each thread mutates the data, will perform `2^FORKS * THREADS * ITERATIONS` iterations. + const ITERATIONS: usize = 100_000; + + let mut mutex_attr = MutexAttr::new().unwrap(); + mutex_attr.set_shared(true).unwrap(); + + let mutex_wrapper = unsafe { + let ptr = mmap::( + None, + NonZeroUsize::new_unchecked(size_of::()), + ProtFlags::PROT_WRITE | ProtFlags::PROT_READ, + MapFlags::MAP_SHARED | MapFlags::MAP_ANONYMOUS, + None, + 0, + ) + .unwrap(); + let mutex_wrapper_ptr = ptr.cast(); + std::ptr::write( + mutex_wrapper_ptr, + MutexWrapper { + lock: Mutex::new(Some(mutex_attr)).unwrap(), + data: UnsafeCell::new(0), + }, + ); + &*mutex_wrapper_ptr + }; + + let fork_results = (0..FORKS) + .map(|_| unsafe { fork().unwrap() }) + .collect::>(); + + let handles = (0..THREADS) + .map(|_| { + std::thread::spawn(|| { + for _ in 0..ITERATIONS { + mutex_wrapper.add(); + } + }) + }) + .collect::>(); + + for handle in handles { + handle.join().unwrap(); + } + + // The root process will be the parent in all its fork results. + let mut root = true; + for fork_result in fork_results { + if let ForkResult::Parent { child } = fork_result { + unsafe { + assert_eq!( + libc::waitpid(child.as_raw(), std::ptr::null_mut(), 0), + child.as_raw() + ); + } + } else { + root = false; + } + } + if root { + let steps = + 2u128.pow(FORKS as u32) * (THREADS as u128) * (ITERATIONS as u128); + assert_eq!(unsafe { *mutex_wrapper.data.get() }, steps); + } +}