Skip to content

Commit

Permalink
feat: Safer epoll timeout
Browse files Browse the repository at this point in the history
  • Loading branch information
JonathanWoollett-Light committed Nov 23, 2023
1 parent c1147c6 commit 0253dfa
Show file tree
Hide file tree
Showing 5 changed files with 249 additions and 238 deletions.
5 changes: 5 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,11 @@ feature! {
#[allow(missing_docs)]
pub mod unistd;

#[cfg(any(feature = "poll", feature = "event"))]
mod timeout;
#[cfg(any(feature = "poll", feature = "event"))]
pub use timeout::*;

use std::ffi::{CStr, CString, OsStr};
use std::mem::MaybeUninit;
use std::os::unix::ffi::OsStrExt;
Expand Down
237 changes: 8 additions & 229 deletions src/poll.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
//! Wait for events to trigger on specific file descriptors
use std::os::unix::io::{AsFd, AsRawFd, BorrowedFd};
use std::time::Duration;

use crate::errno::Errno;
use crate::Result;
use crate::Timeout;

/// This is a wrapper around `libc::pollfd`.
///
/// It's meant to be used as an argument to the [`poll`](fn.poll.html) and
Expand All @@ -27,13 +28,14 @@ impl<'fd> PollFd<'fd> {
/// ```no_run
/// # use std::os::unix::io::{AsFd, AsRawFd, FromRawFd};
/// # use nix::{
/// # poll::{PollTimeout, PollFd, PollFlags, poll},
/// # Timeout,
/// # poll::{PollFd, PollFlags, poll},
/// # unistd::{pipe, read}
/// # };
/// let (r, w) = pipe().unwrap();
/// let pfd = PollFd::new(r.as_fd(), PollFlags::POLLIN);
/// let mut fds = [pfd];
/// poll(&mut fds, PollTimeout::NONE).unwrap();
/// poll(&mut fds, Timeout::NONE).unwrap();
/// let mut buf = [0u8; 80];
/// read(r.as_raw_fd(), &mut buf[..]);
/// ```
Expand Down Expand Up @@ -175,229 +177,6 @@ libc_bitflags! {
}
}

/// Timeout argument for [`poll`].
#[derive(Debug, Clone, Copy, Eq, PartialEq, Ord, PartialOrd)]
pub struct PollTimeout(i32);

impl PollTimeout {
/// Blocks indefinitely.
///
/// > Specifying a negative value in timeout means an infinite timeout.
pub const NONE: Self = Self(-1);
/// Returns immediately.
///
/// > Specifying a timeout of zero causes poll() to return immediately, even if no file
/// > descriptors are ready.
pub const ZERO: Self = Self(0);
/// Blocks for at most [`std::i32::MAX`] milliseconds.
pub const MAX: Self = Self(i32::MAX);
/// Returns if `self` equals [`PollTimeout::NONE`].
pub fn is_none(&self) -> bool {
// > Specifying a negative value in timeout means an infinite timeout.
*self <= Self::NONE
}
/// Returns if `self` does not equal [`PollTimeout::NONE`].
pub fn is_some(&self) -> bool {
!self.is_none()
}
/// Returns the timeout in milliseconds if there is some, otherwise returns `None`.
pub fn as_millis(&self) -> Option<u32> {
self.is_some().then_some(u32::try_from(self.0).unwrap())
}
/// Returns the timeout as a `Duration` if there is some, otherwise returns `None`.
pub fn timeout(&self) -> Option<Duration> {
self.as_millis()
.map(|x| Duration::from_millis(u64::from(x)))
}
}

/// Error type for integer conversions into `PollTimeout`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PollTimeoutTryFromError {
/// Passing a value less than -1 is invalid on some systems, see
/// <https://man.freebsd.org/cgi/man.cgi?poll#end>.
TooNegative,
/// Passing a value greater than `i32::MAX` is invalid.
TooPositive,
}

impl std::fmt::Display for PollTimeoutTryFromError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::TooNegative => write!(f, "Passed a negative timeout less than -1."),
Self::TooPositive => write!(f, "Passed a positive timeout greater than `i32::MAX` milliseconds.")
}
}
}

impl std::error::Error for PollTimeoutTryFromError {}

impl<T: Into<PollTimeout>> From<Option<T>> for PollTimeout {
fn from(x: Option<T>) -> Self {
x.map_or(Self::NONE, |x| x.into())
}
}
impl TryFrom<Duration> for PollTimeout {
type Error = PollTimeoutTryFromError;
fn try_from(x: Duration) -> std::result::Result<Self, Self::Error> {
Ok(Self(
i32::try_from(x.as_millis())
.map_err(|_| PollTimeoutTryFromError::TooPositive)?,
))
}
}
impl TryFrom<u128> for PollTimeout {
type Error = PollTimeoutTryFromError;
fn try_from(x: u128) -> std::result::Result<Self, Self::Error> {
Ok(Self(
i32::try_from(x)
.map_err(|_| PollTimeoutTryFromError::TooPositive)?,
))
}
}
impl TryFrom<u64> for PollTimeout {
type Error = PollTimeoutTryFromError;
fn try_from(x: u64) -> std::result::Result<Self, Self::Error> {
Ok(Self(
i32::try_from(x)
.map_err(|_| PollTimeoutTryFromError::TooPositive)?,
))
}
}
impl TryFrom<u32> for PollTimeout {
type Error = PollTimeoutTryFromError;
fn try_from(x: u32) -> std::result::Result<Self, Self::Error> {
Ok(Self(
i32::try_from(x)
.map_err(|_| PollTimeoutTryFromError::TooPositive)?,
))
}
}
impl From<u16> for PollTimeout {
fn from(x: u16) -> Self {
Self(i32::from(x))
}
}
impl From<u8> for PollTimeout {
fn from(x: u8) -> Self {
Self(i32::from(x))
}
}
impl TryFrom<i128> for PollTimeout {
type Error = PollTimeoutTryFromError;
fn try_from(x: i128) -> std::result::Result<Self, Self::Error> {
match x {
..=-2 => Err(PollTimeoutTryFromError::TooNegative),
-1.. => Ok(Self(
i32::try_from(x)
.map_err(|_| PollTimeoutTryFromError::TooPositive)?,
)),
}
}
}
impl TryFrom<i64> for PollTimeout {
type Error = PollTimeoutTryFromError;
fn try_from(x: i64) -> std::result::Result<Self, Self::Error> {
match x {
..=-2 => Err(PollTimeoutTryFromError::TooNegative),
-1.. => Ok(Self(
i32::try_from(x)
.map_err(|_| PollTimeoutTryFromError::TooPositive)?,
)),
}
}
}
impl TryFrom<i32> for PollTimeout {
type Error = PollTimeoutTryFromError;
fn try_from(x: i32) -> std::result::Result<Self, Self::Error> {
match x {
..=-2 => Err(PollTimeoutTryFromError::TooNegative),
-1.. => Ok(Self(x)),
}
}
}
impl TryFrom<i16> for PollTimeout {
type Error = PollTimeoutTryFromError;
fn try_from(x: i16) -> std::result::Result<Self, Self::Error> {
match x {
..=-2 => Err(PollTimeoutTryFromError::TooNegative),
-1.. => Ok(Self(i32::from(x))),
}
}
}
impl TryFrom<i8> for PollTimeout {
type Error = PollTimeoutTryFromError;
fn try_from(x: i8) -> std::result::Result<Self, Self::Error> {
match x {
..=-2 => Err(PollTimeoutTryFromError::TooNegative),
-1.. => Ok(Self(i32::from(x))),
}
}
}
impl TryFrom<PollTimeout> for Duration {
type Error = ();
fn try_from(x: PollTimeout) -> std::result::Result<Self, ()> {
x.timeout().ok_or(())
}
}
impl TryFrom<PollTimeout> for u128 {
type Error = <Self as TryFrom<i32>>::Error;
fn try_from(x: PollTimeout) -> std::result::Result<Self, Self::Error> {
Self::try_from(x.0)
}
}
impl TryFrom<PollTimeout> for u64 {
type Error = <Self as TryFrom<i32>>::Error;
fn try_from(x: PollTimeout) -> std::result::Result<Self, Self::Error> {
Self::try_from(x.0)
}
}
impl TryFrom<PollTimeout> for u32 {
type Error = <Self as TryFrom<i32>>::Error;
fn try_from(x: PollTimeout) -> std::result::Result<Self, Self::Error> {
Self::try_from(x.0)
}
}
impl TryFrom<PollTimeout> for u16 {
type Error = <Self as TryFrom<i32>>::Error;
fn try_from(x: PollTimeout) -> std::result::Result<Self, Self::Error> {
Self::try_from(x.0)
}
}
impl TryFrom<PollTimeout> for u8 {
type Error = <Self as TryFrom<i32>>::Error;
fn try_from(x: PollTimeout) -> std::result::Result<Self, Self::Error> {
Self::try_from(x.0)
}
}
impl From<PollTimeout> for i128 {
fn from(x: PollTimeout) -> Self {
Self::from(x.0)
}
}
impl From<PollTimeout> for i64 {
fn from(x: PollTimeout) -> Self {
Self::from(x.0)
}
}
impl From<PollTimeout> for i32 {
fn from(x: PollTimeout) -> Self {
x.0
}
}
impl TryFrom<PollTimeout> for i16 {
type Error = <Self as TryFrom<i32>>::Error;
fn try_from(x: PollTimeout) -> std::result::Result<Self, Self::Error> {
Self::try_from(x.0)
}
}
impl TryFrom<PollTimeout> for i8 {
type Error = <Self as TryFrom<i32>>::Error;
fn try_from(x: PollTimeout) -> std::result::Result<Self, Self::Error> {
Self::try_from(x.0)
}
}

/// `poll` waits for one of a set of file descriptors to become ready to perform I/O.
/// ([`poll(2)`](https://pubs.opengroup.org/onlinepubs/9699919799/functions/poll.html))
///
Expand All @@ -414,11 +193,11 @@ impl TryFrom<PollTimeout> for i8 {
///
/// Note that the timeout interval will be rounded up to the system clock
/// granularity, and kernel scheduling delays mean that the blocking
/// interval may overrun by a small amount. Specifying a [`PollTimeout::NONE`]
/// interval may overrun by a small amount. Specifying a [`Timeout::NONE`]
/// in timeout means an infinite timeout. Specifying a timeout of
/// [`PollTimeout::ZERO`] causes `poll()` to return immediately, even if no file
/// [`Timeout::ZERO`] causes `poll()` to return immediately, even if no file
/// descriptors are ready.
pub fn poll<T: Into<PollTimeout>>(
pub fn poll<T: Into<Timeout>>(
fds: &mut [PollFd],
timeout: T,
) -> Result<libc::c_int> {
Expand Down
14 changes: 8 additions & 6 deletions src/sys/epoll.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::errno::Errno;
use crate::Result;
use crate::Timeout;
use libc::{self, c_int};
use std::mem;
use std::os::unix::io::{AsFd, AsRawFd, FromRawFd, OwnedFd, RawFd};
Expand Down Expand Up @@ -73,11 +74,12 @@ impl EpollEvent {
/// ```
/// # use nix::sys::{epoll::{Epoll, EpollEvent, EpollFlags, EpollCreateFlags}, eventfd::{eventfd, EfdFlags}};
/// # use nix::unistd::write;
/// # use nix::Timeout;
/// # use std::os::unix::io::{OwnedFd, FromRawFd, AsFd};
/// # use std::time::{Instant, Duration};
/// # fn main() -> nix::Result<()> {
/// const DATA: u64 = 17;
/// const MILLIS: u64 = 100;
/// const MILLIS: u8 = 100;
///
/// // Create epoll
/// let epoll = Epoll::new(EpollCreateFlags::empty())?;
Expand All @@ -92,11 +94,11 @@ impl EpollEvent {
///
/// // Wait on event
/// let mut events = [EpollEvent::empty()];
/// epoll.wait(&mut events, MILLIS as isize)?;
/// epoll.wait(&mut events, MILLIS)?;
///
/// // Assert data correct & timeout didn't occur
/// assert_eq!(events[0].data(), DATA);
/// assert!(now.elapsed() < Duration::from_millis(MILLIS));
/// assert!(now.elapsed().as_millis() < MILLIS.into());
/// # Ok(())
/// # }
/// ```
Expand Down Expand Up @@ -140,17 +142,17 @@ impl Epoll {
/// (This can be thought of as fetching items from the ready list of the epoll instance.)
///
/// [`epoll_wait`](https://man7.org/linux/man-pages/man2/epoll_wait.2.html)
pub fn wait(
pub fn wait<T: Into<Timeout>>(
&self,
events: &mut [EpollEvent],
timeout: isize,
timeout: T,
) -> Result<usize> {
let res = unsafe {
libc::epoll_wait(
self.0.as_raw_fd(),
events.as_mut_ptr().cast(),
events.len() as c_int,
timeout as c_int,
timeout.into().into(),
)
};

Expand Down
Loading

0 comments on commit 0253dfa

Please sign in to comment.