diff --git a/tokio-epoll-uring/src/system/completion.rs b/tokio-epoll-uring/src/system/completion.rs index 6dcfe19..9189f66 100644 --- a/tokio-epoll-uring/src/system/completion.rs +++ b/tokio-epoll-uring/src/system/completion.rs @@ -642,7 +642,9 @@ mod tests { let (read_task_jh, mut writer) = rt.block_on(async move { let (reader, writer) = os_pipe::pipe().unwrap(); let jh = tokio::spawn(async move { - let system = System::launch_with_testing(Some(testing)).await.unwrap(); + let system = System::launch_with_testing(Some(testing), None) + .await + .unwrap(); let reader = unsafe { OwnedFd::from_raw_fd(nix::unistd::dup(reader.as_raw_fd()).unwrap()) }; let buf = vec![0; 1]; diff --git a/tokio-epoll-uring/src/system/lifecycle.rs b/tokio-epoll-uring/src/system/lifecycle.rs index 20b9da1..ea40cd3 100644 --- a/tokio-epoll-uring/src/system/lifecycle.rs +++ b/tokio-epoll-uring/src/system/lifecycle.rs @@ -21,6 +21,8 @@ use super::{ submission::{SubmitSide, SubmitSideInner, SubmitSideNewArgs}, }; +use slots::SlotsTesting; + /// A running `tokio_epoll_uring` system. Use [`Self::launch`] to start, then [`SystemHandle`] to interact. pub struct System { #[allow(dead_code)] @@ -70,17 +72,19 @@ impl System { /// /// The concept of *poller task* is described in [`crate::doc::design`]. pub async fn launch() -> Result { - Self::launch_with_testing(None).await + Self::launch_with_testing(None, None).await } pub(crate) async fn launch_with_testing( - testing: Option, + poller_testing: Option, + slots_testing: Option, ) -> Result { let id = SYSTEM_ID.fetch_add(1, std::sync::atomic::Ordering::Relaxed); let (submit_side, poller_ready_fut) = { // TODO: should we mlock `slots`? 
io_uring mmap is mlocked, slots are equally important for the system to function; - let (slots_submit_side, slots_completion_side, slots_poller) = super::slots::new(id); + let (slots_submit_side, slots_completion_side, slots_poller) = + super::slots::new(id, slots_testing.unwrap_or_default()); let uring = Box::new( io_uring::IoUring::builder() @@ -171,7 +175,7 @@ impl System { completion_side, system, slots: slots_poller, - testing, + testing: poller_testing, shutdown_rx, }); (submit_side, poller_ready_fut) diff --git a/tokio-epoll-uring/src/system/slots.rs b/tokio-epoll-uring/src/system/slots.rs index 9d9ad12..ad5a875 100644 --- a/tokio-epoll-uring/src/system/slots.rs +++ b/tokio-epoll-uring/src/system/slots.rs @@ -69,6 +69,30 @@ struct SlotsInner { unused_indices: Vec, co_owner_live: [bool; co_owner::NUM_CO_OWNERS], state: SlotsInnerState, + #[cfg(test)] + testing: SlotsTesting, +} + +#[cfg(test)] +pub(crate) struct SlotsTesting { + pub(crate) test_on_wake: Box< + dyn Send + + Sync + + Fn() -> Option>>, + >, +} + +#[cfg(not(test))] +#[derive(Default)] +pub(crate) struct SlotsTesting; + +#[cfg(test)] +impl Default for SlotsTesting { + fn default() -> Self { + Self { + test_on_wake: Box::new(|| None), + } + } } enum SlotsInnerState { @@ -84,6 +108,9 @@ pub(crate) struct SlotHandle { // FIXME: why is this weak? 
slots_weak: SlotsWeak, idx: usize, + #[cfg(test)] + test_on_wake: + std::sync::Mutex>>>, } enum Slot { @@ -101,6 +128,7 @@ enum Slot { pub(super) fn new( id: usize, + #[allow(unused_variables)] testing: SlotsTesting, ) -> ( Slots<{ co_owner::SUBMIT_SIDE }>, Slots<{ co_owner::COMPLETION_SIDE }>, @@ -122,6 +150,8 @@ pub(super) fn new( inner_weak: inner_weak.clone(), }, }, + #[cfg(test)] + testing, }) }); fn make_co_owner(inner: &Arc>) -> Slots { @@ -192,6 +222,8 @@ impl SlotsInner { match waiter.send(SlotHandle { slots_weak: myself.clone(), idx, + #[cfg(test)] + test_on_wake: Mutex::new((self.testing.test_on_wake)()), }) { Ok(()) => { trace!("handed `idx` to a waiter"); @@ -373,6 +405,8 @@ impl Slots<{ co_owner::SUBMIT_SIDE }> { SlotHandle { slots_weak: myself.clone(), idx, + #[cfg(test)] + test_on_wake: Mutex::new((inner.testing.test_on_wake)()), } }), None => { @@ -426,6 +460,8 @@ impl SlotHandle { op: O, ) -> (O::Resources, Result>) { let slot = self; + + // invariant: op.is_some() <=> we haven't observed the poll_fn below complete yet let op = std::sync::Mutex::new(Some(op)); // If this future gets dropped _before_ the op completes, we need to make sure @@ -447,7 +483,10 @@ impl SlotHandle { }; let storage = &mut inner.storage; let slot_storage_mut = &mut storage[slot.idx]; - let slot_mut = slot_storage_mut.as_mut().unwrap(); + // the invariant is: `op.is_some()` <=> we haven't observed the poll_fn below complete yet + let slot_mut = slot_storage_mut + .as_mut() + .expect("op is Some(), so the poll_fn below hasn't returned the slot yet"); match &mut *slot_mut { Slot::Pending { .. } => { // The resource needs to be kept alive until the op completes. @@ -496,16 +535,15 @@ impl SlotHandle { // Now that we've set up the scope guard, get to business. // Inspect the slot to check whether the poller task already processed the completion. // If it has, good for us. - // If not, set up a oneshot to notify us. (TODO: in the hand-rolled futures, this was simply a std::task::Waker, now it's a oneshot.) 
- enum InspectSlotResult { - AlreadyDone(i32), - NeedToWait, - ShutDown, - } + // If not, store a waker in the slot so the poller task will wake us up to poll again + // and observe the Slot::Ready then. + // + // If we get cancelled in the meantime (i.e., this future gets dropped), the scopeguard + // will make sure the resources stay alive until the op is complete. let mut poll_count = 0; - let inspect_slot_res = poll_fn(|cx| { + let poll_res = poll_fn(|cx| { poll_count += 1; - let inspect_slot_res = slot.slots_weak.try_upgrade_mut(move |inner| { + let try_upgrade_res = slot.slots_weak.try_upgrade_mut(|inner| { let storage = &mut inner.storage; let slot_storage_ref = &mut storage[slot.idx]; let slot_mut = slot_storage_ref.as_mut().unwrap(); @@ -517,7 +555,7 @@ impl SlotHandle { if !cx.waker().will_wake(waker_mut_ref) { waker.replace(cx.waker().clone()); } - InspectSlotResult::NeedToWait + Poll::Pending } Slot::PendingButFutureDropped { .. } => { unreachable!("if it's dropped, it's not pollable") @@ -526,59 +564,51 @@ impl SlotHandle { trace!("op is ready, returning resources to user"); let res = *res; inner.return_slot(slot.idx); - InspectSlotResult::AlreadyDone(res) + // SAFETY: the slot is ready, so, ownership is back with userspace. + #[allow(unused_unsafe)] + unsafe { + let op = op.lock().unwrap().take().unwrap(); + Poll::Ready(op.on_op_completion(res)) + } } } }); - let inspect_slot_res = match inspect_slot_res { - Err(()) => InspectSlotResult::ShutDown, - Ok(res) => res, - }; - match inspect_slot_res { - InspectSlotResult::NeedToWait => Poll::Pending, - x => Poll::Ready(x), + match try_upgrade_res { + Err(()) => { + // SAFETY: + // This future has an outdated view of the system; it shut down in the meantime. + // Shutdown makes sure that all inflight ops complete, so, + // these resources are no longer owned by the kernel and can be returned as an error. 
+ #[allow(unused_unsafe)] + unsafe { + let op = op.lock().unwrap().take().unwrap(); + Poll::Ready(( + op.on_failed_submission(), + Err(Error::System(SystemError::SystemShuttingDown)), + )) + } + } + Ok(Poll::Ready((resources, res))) => { + Poll::Ready((resources, res.map_err(Error::Op))) + } + Ok(Poll::Pending) => Poll::Pending, } }) .await; assert!(poll_count >= 1); - assert!( - !matches!(inspect_slot_res, InspectSlotResult::NeedToWait), - "poll_fn closure returns Pending in that case" - ); - - let res = match inspect_slot_res { - InspectSlotResult::AlreadyDone(r) => r, - InspectSlotResult::NeedToWait => { - unreachable!() + #[cfg(test)] + { + let on_wake = { slot.test_on_wake.lock().unwrap().take() }; + if let Some(on_wake) = on_wake { + let (tx, rx) = tokio::sync::oneshot::channel(); + on_wake.send(tx).unwrap(); + rx.await.unwrap(); } - InspectSlotResult::ShutDown => { - // SAFETY: - // This future has an outdated view of the system; it shut down in the meantime. - // Shutdown makes sure that all inflight ops complete, so, - // these resources are no longer owned by the kernel and can be returned as an error. - #[allow(unused_unsafe)] - unsafe { - let op = op.lock().unwrap().take().unwrap(); - return ( - op.on_failed_submission(), - Err(Error::System(SystemError::SystemShuttingDown)), - ); - } - } - }; - + } if poll_count == 1 && *crate::env_tunables::YIELD_TO_EXECUTOR_IF_READY_ON_FIRST_POLL { tokio::task::yield_now().await; } - - // SAFETY: - // We got a result, so, kernel is done with the operation and ownership is back with us. 
- #[allow(unused_unsafe)] - let (resources, res) = unsafe { - let op = op.lock().unwrap().take().expect("we only take() it in drop(), and evidently drop() hasn't happened yet because we're executing a method on self"); - op.on_op_completion(res) - }; - (resources, res.map_err(Error::Op)) + poll_res } } @@ -607,3 +637,43 @@ impl Slot { } } } + +#[cfg(test)] +mod tests { + use std::sync::{Arc, Mutex}; + + use crate::{system::slots::SlotsTesting, System}; + + // Regression-test for issue https://github.com/neondatabase/tokio-epoll-uring/issues/37 + #[tokio::test] + async fn test_wait_for_completion_drop_behavior() { + let (tx, rx) = tokio::sync::oneshot::channel(); + let tx = Arc::new(Mutex::new(Some(tx))); + let system = System::launch_with_testing( + None, + Some(SlotsTesting { + test_on_wake: Box::new(move || { + Some( + tx.lock() + .unwrap() + .take() + .expect("should only be called once, we only submit one nop here"), + ) + }), + }), + ) + .await + .unwrap(); + let nop = tokio::spawn(system.nop()); + let at_yield_point: tokio::sync::oneshot::Sender<()> = rx.await.unwrap(); + nop.abort(); + let Err(join_err) = nop.await else { + panic!("expecting join error after abort"); + }; + assert!(join_err.is_cancelled()); + assert!( + at_yield_point.is_closed(), + "abort drops the nop op, and hence the oneshot receiver" + ); + } +} diff --git a/tokio-epoll-uring/src/system/test_util/shared_system_handle.rs b/tokio-epoll-uring/src/system/test_util/shared_system_handle.rs index 2b34a15..f04d8bb 100644 --- a/tokio-epoll-uring/src/system/test_util/shared_system_handle.rs +++ b/tokio-epoll-uring/src/system/test_util/shared_system_handle.rs @@ -28,7 +28,7 @@ impl SharedSystemHandle { pub(crate) async fn launch_with_testing( poller_testing: Option, ) -> Result { - let handle = System::launch_with_testing(poller_testing).await?; + let handle = System::launch_with_testing(poller_testing, None).await?; Ok(Self(Arc::new(RwLock::new(Some(handle))))) }