diff --git a/crossbeam-channel/src/channel.rs b/crossbeam-channel/src/channel.rs index 5447e3303..5aafa6c28 100644 --- a/crossbeam-channel/src/channel.rs +++ b/crossbeam-channel/src/channel.rs @@ -14,6 +14,7 @@ use crate::err::{ }; use crate::flavors; use crate::select::{Operation, SelectHandle, Token}; +use crate::waker::BlockingState; /// Creates a multi-producer multi-consumer channel of unbounded capacity. /// @@ -1358,6 +1359,14 @@ impl fmt::Debug for IntoIter { } impl SelectHandle for Sender { + fn start(&self) -> Option> { + match &self.flavor { + SenderFlavor::Array(chan) => chan.sender().start_ref(), + SenderFlavor::List(chan) => chan.sender().start_ref(), + SenderFlavor::Zero(chan) => chan.start(), + } + } + fn try_select(&self, token: &mut Token) -> bool { match &self.flavor { SenderFlavor::Array(chan) => chan.sender().try_select(token), @@ -1370,11 +1379,11 @@ impl SelectHandle for Sender { None } - fn register(&self, oper: Operation, cx: &Context) -> bool { + fn register(&self, oper: Operation, cx: &Context, state: Option<&BlockingState<'_>>) -> bool { match &self.flavor { - SenderFlavor::Array(chan) => chan.sender().register(oper, cx), - SenderFlavor::List(chan) => chan.sender().register(oper, cx), - SenderFlavor::Zero(chan) => chan.sender().register(oper, cx), + SenderFlavor::Array(chan) => chan.sender().register(oper, cx, state), + SenderFlavor::List(chan) => chan.sender().register(oper, cx, state), + SenderFlavor::Zero(chan) => chan.sender().register(oper, cx, state), } } @@ -1402,11 +1411,11 @@ impl SelectHandle for Sender { } } - fn watch(&self, oper: Operation, cx: &Context) -> bool { + fn watch(&self, oper: Operation, cx: &Context, state: Option<&BlockingState<'_>>) -> bool { match &self.flavor { - SenderFlavor::Array(chan) => chan.sender().watch(oper, cx), - SenderFlavor::List(chan) => chan.sender().watch(oper, cx), - SenderFlavor::Zero(chan) => chan.sender().watch(oper, cx), + SenderFlavor::Array(chan) => chan.sender().watch(oper, cx, state), + SenderFlavor::List(chan) => chan.sender().watch(oper, cx, state), + SenderFlavor::Zero(chan) => chan.sender().watch(oper, cx, state), } } @@ -1420,6 +1429,17 @@ impl SelectHandle for Sender { } impl SelectHandle for Receiver { + fn start(&self) -> Option> { + match &self.flavor { + ReceiverFlavor::Array(chan) => chan.receiver().start_ref(), + ReceiverFlavor::List(chan) => chan.receiver().start_ref(), + ReceiverFlavor::Zero(chan) => chan.start(), + ReceiverFlavor::At(chan) => chan.start(), + ReceiverFlavor::Tick(chan) => chan.start(), + ReceiverFlavor::Never(chan) => chan.start(), + } + } + fn try_select(&self, token: &mut Token) -> bool { match &self.flavor { ReceiverFlavor::Array(chan) => chan.receiver().try_select(token), @@ -1442,14 +1462,14 @@ impl SelectHandle for Receiver { } } - fn register(&self, oper: Operation, cx: &Context) -> bool { + fn register(&self, oper: Operation, cx: &Context, state: Option<&BlockingState<'_>>) -> bool { match &self.flavor { - ReceiverFlavor::Array(chan) => chan.receiver().register(oper, cx), - ReceiverFlavor::List(chan) => chan.receiver().register(oper, cx), - ReceiverFlavor::Zero(chan) => chan.receiver().register(oper, cx), - ReceiverFlavor::At(chan) => chan.register(oper, cx), - ReceiverFlavor::Tick(chan) => chan.register(oper, cx), - ReceiverFlavor::Never(chan) => chan.register(oper, cx), + ReceiverFlavor::Array(chan) => chan.receiver().register(oper, cx, state), + ReceiverFlavor::List(chan) => chan.receiver().register(oper, cx, state), + ReceiverFlavor::Zero(chan) => chan.receiver().register(oper, cx, state), + ReceiverFlavor::At(chan) => chan.register(oper, cx, state), + ReceiverFlavor::Tick(chan) => chan.register(oper, cx, state), + ReceiverFlavor::Never(chan) => chan.register(oper, cx, state), } } @@ -1486,14 +1506,14 @@ impl SelectHandle for Receiver { } } - fn watch(&self, oper: Operation, cx: &Context) -> bool { + fn watch(&self, oper: Operation, cx: &Context, state: Option<&BlockingState<'_>>) -> bool { match &self.flavor { - ReceiverFlavor::Array(chan) => chan.receiver().watch(oper, cx), - ReceiverFlavor::List(chan) => chan.receiver().watch(oper, cx), - ReceiverFlavor::Zero(chan) => chan.receiver().watch(oper, cx), - ReceiverFlavor::At(chan) => chan.watch(oper, cx), - ReceiverFlavor::Tick(chan) => chan.watch(oper, cx), - ReceiverFlavor::Never(chan) => chan.watch(oper, cx), + ReceiverFlavor::Array(chan) => chan.receiver().watch(oper, cx, state), + ReceiverFlavor::List(chan) => chan.receiver().watch(oper, cx, state), + ReceiverFlavor::Zero(chan) => chan.receiver().watch(oper, cx, state), + ReceiverFlavor::At(chan) => chan.watch(oper, cx, state), + ReceiverFlavor::Tick(chan) => chan.watch(oper, cx, state), + ReceiverFlavor::Never(chan) => chan.watch(oper, cx, state), } } diff --git a/crossbeam-channel/src/flavors/array.rs b/crossbeam-channel/src/flavors/array.rs index 206a05a86..dafe0aaad 100644 --- a/crossbeam-channel/src/flavors/array.rs +++ b/crossbeam-channel/src/flavors/array.rs @@ -20,7 +20,7 @@ use crossbeam_utils::{Backoff, CachePadded}; use crate::context::Context; use crate::err::{RecvTimeoutError, SendTimeoutError, TryRecvError, TrySendError}; use crate::select::{Operation, SelectHandle, Selected, Token}; -use crate::waker::SyncWaker; +use crate::waker::{BlockingState, SyncWaker}; /// A slot in a channel. struct Slot { @@ -87,6 +87,18 @@ pub(crate) struct Channel { receivers: SyncWaker, } +/// The state of the channel after calling `start_recv` or `start_send`. +#[derive(PartialEq, Eq)] +enum Status { + /// The channel is ready to read or write to. + Ready, + /// There is currently a send or receive in progress holding up the queue. + /// All operations must block to preserve linearizability. + InProgress, + /// The channel is empty. + Empty, +} + impl Channel { /// Creates a bounded channel of capacity `cap`. pub(crate) fn with_capacity(cap: usize) -> Self { @@ -135,7 +147,7 @@ impl Channel { } /// Attempts to reserve a slot for sending a message. - fn start_send(&self, token: &mut Token) -> bool { + fn start_send(&self, token: &mut Token) -> Status { let backoff = Backoff::new(); let mut tail = self.tail.load(Ordering::Relaxed); @@ -144,7 +156,7 @@ impl Channel { if tail & self.mark_bit != 0 { token.array.slot = ptr::null(); token.array.stamp = 0; - return true; + return Status::Ready; } // Deconstruct the tail. @@ -179,7 +191,7 @@ impl Channel { // Prepare the token for the follow-up call to `write`. token.array.slot = slot as *const Slot as *const u8; token.array.stamp = tail + 1; - return true; + return Status::Ready; } Err(t) => { tail = t; @@ -193,7 +205,14 @@ impl Channel { // If the head lags one lap behind the tail as well... if head.wrapping_add(self.one_lap) == tail { // ...then the channel is full. - return false; + return Status::Empty; + } + + // The head was advanced but the stamp hasn't been updated yet, + // meaning a receive is in-progress. Spin for a bit waiting for + // the receive to complete before falling back to blocking. + if backoff.is_completed() { + return Status::InProgress; } backoff.spin(); @@ -225,7 +244,7 @@ impl Channel { } /// Attempts to reserve a slot for receiving a message. - fn start_recv(&self, token: &mut Token) -> bool { + fn start_recv(&self, token: &mut Token) -> Status { let backoff = Backoff::new(); let mut head = self.head.load(Ordering::Relaxed); @@ -262,7 +281,7 @@ impl Channel { // Prepare the token for the follow-up call to `read`. token.array.slot = slot as *const Slot as *const u8; token.array.stamp = head.wrapping_add(self.one_lap); - return true; + return Status::Ready; } Err(h) => { head = h; @@ -280,13 +299,20 @@ impl Channel { // ...then receive an error. token.array.slot = ptr::null(); token.array.stamp = 0; - return true; + return Status::Ready; } else { // Otherwise, the receive operation is not ready. - return false; + return Status::Empty; } } + // The tail was advanced but the stamp hasn't been updated yet, + // meaning a send is in-progress. Spin for a bit waiting for + // the send to complete before falling back to blocking. + if backoff.is_completed() { + return Status::InProgress; + } + backoff.spin(); head = self.head.load(Ordering::Relaxed); } else { @@ -317,11 +343,13 @@ impl Channel { /// Attempts to send a message into the channel. pub(crate) fn try_send(&self, msg: T) -> Result<(), TrySendError> { - let token = &mut Token::default(); - if self.start_send(token) { - unsafe { self.write(token, msg).map_err(TrySendError::Disconnected) } - } else { - Err(TrySendError::Full(msg)) + match self.send_blocking(msg, None, false) { + Ok(None) => Ok(()), + Ok(Some(msg)) => Err(TrySendError::Full(msg)), + Err(SendTimeoutError::Disconnected(msg)) => Err(TrySendError::Disconnected(msg)), + Err(SendTimeoutError::Timeout(_)) => { + unreachable!("called recv_blocking with deadline: None") + } } } @@ -331,14 +359,30 @@ impl Channel { msg: T, deadline: Option, ) -> Result<(), SendTimeoutError> { + self.send_blocking(msg, deadline, true) + .map(|value| assert!(value.is_none(), "called send_blocking with block: true")) + } + + /// Sends a message into the channel. + pub(crate) fn send_blocking( + &self, + msg: T, + deadline: Option, + block: bool, + ) -> Result, SendTimeoutError> { let token = &mut Token::default(); + let mut state = self.senders.start(); loop { // Try sending a message several times. let backoff = Backoff::new(); loop { - if self.start_send(token) { - let res = unsafe { self.write(token, msg) }; - return res.map_err(SendTimeoutError::Disconnected); + match self.start_send(token) { + Status::Ready => { + let res = unsafe { self.write(token, msg) }; + return res.map(|_| None).map_err(SendTimeoutError::Disconnected); + } + Status::Empty if !block => return Ok(Some(msg)), + _ => {} } if backoff.is_completed() { @@ -357,7 +401,7 @@ impl Channel { Context::with(|cx| { // Prepare for blocking until a receiver wakes us up. let oper = Operation::hook(token); - self.senders.register(oper, cx); + self.senders.register(oper, cx, &state); // Has the channel become ready just now? if !self.is_full() || self.is_disconnected() { @@ -375,30 +419,47 @@ impl Channel { Selected::Operation(_) => {} } }); + + state.unpark(); } } /// Attempts to receive a message without blocking. pub(crate) fn try_recv(&self) -> Result { - let token = &mut Token::default(); - - if self.start_recv(token) { - unsafe { self.read(token).map_err(|_| TryRecvError::Disconnected) } - } else { - Err(TryRecvError::Empty) + match self.recv_blocking(None, false) { + Ok(Some(value)) => Ok(value), + Ok(None) => Err(TryRecvError::Empty), + Err(RecvTimeoutError::Disconnected) => Err(TryRecvError::Disconnected), + Err(RecvTimeoutError::Timeout) => { + unreachable!("called recv_blocking with deadline: None") + } } } /// Receives a message from the channel. pub(crate) fn recv(&self, deadline: Option) -> Result { + self.recv_blocking(deadline, true) + .map(|value| value.expect("called recv_blocking with block: true")) + } + + pub(crate) fn recv_blocking( + &self, + deadline: Option, + block: bool, + ) -> Result, RecvTimeoutError> { let token = &mut Token::default(); + let mut state = self.receivers.start(); loop { // Try receiving a message several times. let backoff = Backoff::new(); loop { - if self.start_recv(token) { - let res = unsafe { self.read(token) }; - return res.map_err(|_| RecvTimeoutError::Disconnected); + match self.start_recv(token) { + Status::Ready => { + let res = unsafe { self.read(token) }; + return res.map(Some).map_err(|_| RecvTimeoutError::Disconnected); + } + Status::Empty if !block => return Ok(None), + _ => {} } if backoff.is_completed() { @@ -417,7 +478,7 @@ impl Channel { Context::with(|cx| { // Prepare for blocking until a sender wakes us up. let oper = Operation::hook(token); - self.receivers.register(oper, cx); + self.receivers.register(oper, cx, &state); // Has the channel become ready just now? if !self.is_empty() || self.is_disconnected() { @@ -437,6 +498,8 @@ impl Channel { Selected::Operation(_) => {} } }); + + state.unpark(); } } @@ -566,17 +629,29 @@ pub(crate) struct Receiver<'a, T>(&'a Channel); /// Sender handle to a channel. pub(crate) struct Sender<'a, T>(&'a Channel); +impl<'a, T> Receiver<'a, T> { + /// Same as `SelectHandle::start`, but with a more specific lifetime. + pub(crate) fn start_ref(&self) -> Option> { + Some(self.0.receivers.start()) + } +} + impl SelectHandle for Receiver<'_, T> { + fn start(&self) -> Option> { + self.start_ref() + } + fn try_select(&self, token: &mut Token) -> bool { - self.0.start_recv(token) + self.0.start_recv(token) == Status::Ready } fn deadline(&self) -> Option { None } - fn register(&self, oper: Operation, cx: &Context) -> bool { - self.0.receivers.register(oper, cx); + fn register(&self, oper: Operation, cx: &Context, state: Option<&BlockingState<'_>>) -> bool { + let state = state.expect("Receiver::start returns blocking state"); + self.0.receivers.register(oper, cx, state); self.is_ready() } @@ -592,8 +667,9 @@ impl SelectHandle for Receiver<'_, T> { !self.0.is_empty() || self.0.is_disconnected() } - fn watch(&self, oper: Operation, cx: &Context) -> bool { - self.0.receivers.watch(oper, cx); + fn watch(&self, oper: Operation, cx: &Context, state: Option<&BlockingState<'_>>) -> bool { + let state = state.expect("Receiver::start returns blocking state"); + self.0.receivers.watch(oper, cx, state); self.is_ready() } @@ -602,17 +678,29 @@ impl SelectHandle for Receiver<'_, T> { } } +impl<'a, T> Sender<'a, T> { + /// Same as `SelectHandle::start`, but with a more specific lifetime. + pub(crate) fn start_ref(&self) -> Option> { + Some(self.0.senders.start()) + } +} + impl SelectHandle for Sender<'_, T> { + fn start(&self) -> Option> { + self.start_ref() + } + fn try_select(&self, token: &mut Token) -> bool { - self.0.start_send(token) + self.0.start_send(token) == Status::Ready } fn deadline(&self) -> Option { None } - fn register(&self, oper: Operation, cx: &Context) -> bool { - self.0.senders.register(oper, cx); + fn register(&self, oper: Operation, cx: &Context, state: Option<&BlockingState<'_>>) -> bool { + let state = state.expect("Sender::start returns blocking state"); + self.0.senders.register(oper, cx, state); self.is_ready() } @@ -628,8 +716,9 @@ impl SelectHandle for Sender<'_, T> { !self.0.is_full() || self.0.is_disconnected() } - fn watch(&self, oper: Operation, cx: &Context) -> bool { - self.0.senders.watch(oper, cx); + fn watch(&self, oper: Operation, cx: &Context, state: Option<&BlockingState<'_>>) -> bool { + let state = state.expect("Sender::start returns blocking state"); + self.0.senders.watch(oper, cx, state); self.is_ready() } diff --git a/crossbeam-channel/src/flavors/at.rs b/crossbeam-channel/src/flavors/at.rs index 83e69c1ed..177d5d1a7 100644 --- a/crossbeam-channel/src/flavors/at.rs +++ b/crossbeam-channel/src/flavors/at.rs @@ -10,6 +10,7 @@ use crate::context::Context; use crate::err::{RecvTimeoutError, TryRecvError}; use crate::select::{Operation, SelectHandle, Token}; use crate::utils; +use crate::waker::BlockingState; /// Result of a receive operation. pub(crate) type AtToken = Option; @@ -140,6 +141,10 @@ impl Channel { } impl SelectHandle for Channel { + fn start(&self) -> Option> { + None + } + #[inline] fn try_select(&self, token: &mut Token) -> bool { match self.try_recv() { @@ -166,7 +171,12 @@ impl SelectHandle for Channel { } #[inline] - fn register(&self, _oper: Operation, _cx: &Context) -> bool { + fn register( + &self, + _oper: Operation, + _cx: &Context, + _state: Option<&BlockingState<'_>>, + ) -> bool { self.is_ready() } @@ -184,7 +194,7 @@ impl SelectHandle for Channel { } #[inline] - fn watch(&self, _oper: Operation, _cx: &Context) -> bool { + fn watch(&self, _oper: Operation, _cx: &Context, _state: Option<&BlockingState<'_>>) -> bool { self.is_ready() } diff --git a/crossbeam-channel/src/flavors/list.rs b/crossbeam-channel/src/flavors/list.rs index e86551ad2..058958e17 100644 --- a/crossbeam-channel/src/flavors/list.rs +++ b/crossbeam-channel/src/flavors/list.rs @@ -4,7 +4,7 @@ use std::boxed::Box; use std::cell::UnsafeCell; use std::marker::PhantomData; use std::mem::MaybeUninit; -use std::ptr; +use std::ptr::{self, NonNull}; use std::sync::atomic::{self, AtomicPtr, AtomicUsize, Ordering}; use std::time::Instant; @@ -13,7 +13,7 @@ use crossbeam_utils::{Backoff, CachePadded}; use crate::context::Context; use crate::err::{RecvTimeoutError, SendTimeoutError, TryRecvError, TrySendError}; use crate::select::{Operation, SelectHandle, Selected, Token}; -use crate::waker::SyncWaker; +use crate::waker::{BlockingState, SyncWaker}; // TODO(stjepang): Once we bump the minimum required Rust version to 1.28 or newer, re-apply the // following changes by @kleimkuhler: @@ -56,10 +56,48 @@ impl Slot { }; /// Waits until a message is written into the slot. - fn wait_write(&self) { - let backoff = Backoff::new(); - while self.state.load(Ordering::Acquire) & WRITE == 0 { - backoff.snooze(); + fn wait_write(&self, receivers: &SyncWaker, token: &mut Token) { + let mut state = receivers.start(); + + loop { + // Try reading the message several times. + let backoff = Backoff::new(); + loop { + if self.state.load(Ordering::Acquire) & WRITE != 0 { + return; + } + + if backoff.is_completed() { + break; + } else { + backoff.snooze(); + } + } + + // Prepare for blocking until a sender wakes us up. + Context::with(|cx| { + let oper = Operation::hook(token); + // Register to be notified after any message is sent. + receivers.watch(oper, cx, &state); + + // Was the emssage just sent? + if self.state.load(Ordering::Acquire) & WRITE != 0 { + let _ = cx.try_select(Selected::Aborted); + } + + // Block the current thread. + let sel = cx.wait_until(None); + + match sel { + Selected::Waiting => unreachable!(), + Selected::Aborted | Selected::Disconnected => { + receivers.unwatch(oper); + } + Selected::Operation(_) => {} + } + + state.unpark(); + }); } } } @@ -85,14 +123,47 @@ impl Block { } /// Waits until the next pointer is set. - fn wait_next(&self) -> *mut Self { - let backoff = Backoff::new(); + fn wait_next(&self, receivers: &SyncWaker, token: &mut Token) -> *mut Self { + let mut state = receivers.start(); loop { - let next = self.next.load(Ordering::Acquire); - if !next.is_null() { - return next; + // Try reading the message several times. + let backoff = Backoff::new(); + loop { + if let Some(next) = NonNull::new(self.next.load(Ordering::Acquire)) { + return next.as_ptr(); + } + + if backoff.is_completed() { + break; + } else { + backoff.snooze(); + } } - backoff.snooze(); + + // Prepare for blocking until a sender wakes us up. + Context::with(|cx| { + let oper = Operation::hook(token); + // Register to be notified after any message is sent. + receivers.watch(oper, cx, &state); + + // Was the next pointer just written? + if !self.next.load(Ordering::Acquire).is_null() { + let _ = cx.try_select(Selected::Aborted); + } + + // Block the current thread. + let sel = cx.wait_until(None); + + match sel { + Selected::Waiting => unreachable!(), + Selected::Aborted | Selected::Disconnected => { + receivers.unwatch(oper); + } + Selected::Operation(_) => {} + } + + state.unpark(); + }); } } @@ -168,6 +239,18 @@ pub(crate) struct Channel { _marker: PhantomData, } +/// The status of the channel after calling `start_recv`. +#[derive(PartialEq, Eq)] +enum Status { + /// The channel has a message ready to read. + Ready, + /// There is currently a send in progress holding up the queue. + /// Both `recv` and `try_recv` must block to preserve linearizability. + InProgress, + /// The channel is empty. + Empty, +} + impl Channel { /// Creates a new unbounded channel. pub(crate) fn new() -> Self { @@ -196,7 +279,7 @@ impl Channel { } /// Attempts to reserve a slot for sending a message. - fn start_send(&self, token: &mut Token) -> bool { + fn start_send(&self, token: &mut Token) -> Status { let backoff = Backoff::new(); let mut tail = self.tail.index.load(Ordering::Acquire); let mut block = self.tail.block.load(Ordering::Acquire); @@ -206,14 +289,19 @@ impl Channel { // Check if the channel is disconnected. if tail & MARK_BIT != 0 { token.list.block = ptr::null(); - return true; + return Status::Ready; } // Calculate the offset of the index into the block. let offset = (tail >> SHIFT) % LAP; // If we reached the end of the block, wait until the next one is installed. + // If we've been waiting for too long, fallback to blocking. if offset == BLOCK_CAP { + if backoff.is_completed() { + return Status::InProgress; + } + backoff.snooze(); tail = self.tail.index.load(Ordering::Acquire); block = self.tail.block.load(Ordering::Acquire); @@ -267,7 +355,7 @@ impl Channel { token.list.block = block as *const u8; token.list.offset = offset; - return true; + return Status::Ready; }, Err(t) => { tail = t; @@ -298,7 +386,7 @@ impl Channel { } /// Attempts to reserve a slot for receiving a message. - fn start_recv(&self, token: &mut Token) -> bool { + fn start_recv(&self, token: &mut Token) -> Status { let backoff = Backoff::new(); let mut head = self.head.index.load(Ordering::Acquire); let mut block = self.head.block.load(Ordering::Acquire); @@ -307,8 +395,14 @@ impl Channel { // Calculate the offset of the index into the block. let offset = (head >> SHIFT) % LAP; - // If we reached the end of the block, wait until the next one is installed. + // We reached the end of the block but the block is not installed yet, meaning + // the last send on the previous block is still in progress. The send is likely to + // be soon so we spin here before falling back to blocking. if offset == BLOCK_CAP { + if backoff.is_completed() { + return Status::InProgress; + } + backoff.snooze(); head = self.head.index.load(Ordering::Acquire); block = self.head.block.load(Ordering::Acquire); @@ -327,10 +421,10 @@ impl Channel { if tail & MARK_BIT != 0 { // ...then receive an error. token.list.block = ptr::null(); - return true; + return Status::Ready; } else { // Otherwise, the receive operation is not ready. - return false; + return Status::Empty; } } @@ -340,9 +434,14 @@ impl Channel { } } - // The block can be null here only if the first message is being sent into the channel. - // In that case, just wait until it gets initialized. + // The block can be null here only if the first message sent into the channel is + // in progress. The send is likely to complete soon so we spin here before falling + // back to blocking. if block.is_null() { + if backoff.is_completed() { + return Status::InProgress; + } + backoff.snooze(); head = self.head.index.load(Ordering::Acquire); block = self.head.block.load(Ordering::Acquire); @@ -359,7 +458,7 @@ impl Channel { Ok(_) => unsafe { // If we've reached the end of the block, move to the next one. if offset + 1 == BLOCK_CAP { - let next = (*block).wait_next(); + let next = (*block).wait_next(&self.receivers, token); let mut next_index = (new_head & !MARK_BIT).wrapping_add(1 << SHIFT); if !(*next).next.load(Ordering::Relaxed).is_null() { next_index |= MARK_BIT; @@ -371,7 +470,7 @@ impl Channel { token.list.block = block as *const u8; token.list.offset = offset; - return true; + return Status::Ready; }, Err(h) => { head = h; @@ -393,7 +492,7 @@ impl Channel { let block = token.list.block as *mut Block; let offset = token.list.offset; let slot = unsafe { (*block).slots.get_unchecked(offset) }; - slot.wait_write(); + slot.wait_write(&self.receivers, token); let msg = unsafe { slot.msg.get().read().assume_init() }; // Destroy the block if we've reached the end, or if another thread wanted to destroy but @@ -424,35 +523,96 @@ impl Channel { _deadline: Option, ) -> Result<(), SendTimeoutError> { let token = &mut Token::default(); - assert!(self.start_send(token)); - unsafe { - self.write(token, msg) - .map_err(SendTimeoutError::Disconnected) + + // It's possible that we can't proceed because of the sender that + // is supposed to install the next block lagging, so we might have to + // block for a message to be sent. + let mut state = self.receivers.start(); + let mut succeeded = false; + loop { + // Try sending a message several times. + let backoff = Backoff::new(); + loop { + if succeeded || self.start_send(token) == Status::Ready { + return unsafe { + self.write(token, msg) + .map_err(SendTimeoutError::Disconnected) + }; + } + + if backoff.is_completed() { + break; + } else { + backoff.snooze(); + } + } + + // Prepare for blocking until a sender wakes us up. + Context::with(|cx| { + let oper = Operation::hook(token); + // Register to be notified after any message is sent. + self.receivers.watch(oper, cx, &state); + + // Has the channel become ready just now? + if self.start_send(token) == Status::Ready { + succeeded = true; + let _ = cx.try_select(Selected::Aborted); + } + + // Block the current thread. + let sel = cx.wait_until(None); + + match sel { + Selected::Waiting => unreachable!(), + Selected::Aborted | Selected::Disconnected => { + self.receivers.unwatch(oper); + } + Selected::Operation(_) => {} + } + + state.unpark(); + }); } } /// Attempts to receive a message without blocking. pub(crate) fn try_recv(&self) -> Result { - let token = &mut Token::default(); - - if self.start_recv(token) { - unsafe { self.read(token).map_err(|_| TryRecvError::Disconnected) } - } else { - Err(TryRecvError::Empty) + match self.recv_blocking(None, false) { + Ok(Some(value)) => Ok(value), + Ok(None) => Err(TryRecvError::Empty), + Err(RecvTimeoutError::Disconnected) => Err(TryRecvError::Disconnected), + Err(RecvTimeoutError::Timeout) => { + unreachable!("called recv_blocking with deadline: None") + } } } /// Receives a message from the channel. pub(crate) fn recv(&self, deadline: Option) -> Result { + self.recv_blocking(deadline, true) + .map(|value| value.expect("called recv_blocking with block: true")) + } + + /// Receives a message from the channel. + pub(crate) fn recv_blocking( + &self, + deadline: Option, + block: bool, + ) -> Result, RecvTimeoutError> { let token = &mut Token::default(); + + let mut state = self.receivers.start(); loop { // Try receiving a message several times. let backoff = Backoff::new(); loop { - if self.start_recv(token) { - unsafe { - return self.read(token).map_err(|_| RecvTimeoutError::Disconnected); + match self.start_recv(token) { + Status::Ready => { + let res = unsafe { self.read(token) }; + return res.map(Some).map_err(|_| RecvTimeoutError::Disconnected); } + Status::Empty if !block => return Ok(None), + _ => {} } if backoff.is_completed() { @@ -471,7 +631,7 @@ impl Channel { // Prepare for blocking until a sender wakes us up. Context::with(|cx| { let oper = Operation::hook(token); - self.receivers.register(oper, cx); + self.receivers.register(oper, cx, &state); // Has the channel become ready just now? if !self.is_empty() || self.is_disconnected() { @@ -490,6 +650,8 @@ impl Channel { } Selected::Operation(_) => {} } + + state.unpark(); }); } } @@ -569,6 +731,7 @@ impl Channel { /// /// This method should only be called when all receivers are dropped. fn discard_all_messages(&self) { + let token = &mut Token::default(); let backoff = Backoff::new(); let mut tail = self.tail.index.load(Ordering::Acquire); loop { @@ -610,10 +773,10 @@ impl Channel { if offset < BLOCK_CAP { // Drop the message in the slot. let slot = (*block).slots.get_unchecked(offset); - slot.wait_write(); + slot.wait_write(&self.receivers, token); (*slot.msg.get()).assume_init_drop(); } else { - (*block).wait_next(); + (*block).wait_next(&self.receivers, token); // Deallocate the block and move to the next one. let next = (*block).next.load(Ordering::Acquire); drop(Box::from_raw(block)); @@ -693,17 +856,29 @@ pub(crate) struct Receiver<'a, T>(&'a Channel); /// Sender handle to a channel. pub(crate) struct Sender<'a, T>(&'a Channel); +impl<'a, T> Receiver<'a, T> { + /// Same as `SelectHandle::start`, but with a more specific lifetime. + pub(crate) fn start_ref(&self) -> Option> { + Some(self.0.receivers.start()) + } +} + impl SelectHandle for Receiver<'_, T> { + fn start(&self) -> Option> { + self.start_ref() + } + fn try_select(&self, token: &mut Token) -> bool { - self.0.start_recv(token) + self.0.start_recv(token) == Status::Ready } fn deadline(&self) -> Option { None } - fn register(&self, oper: Operation, cx: &Context) -> bool { - self.0.receivers.register(oper, cx); + fn register(&self, oper: Operation, cx: &Context, state: Option<&BlockingState<'_>>) -> bool { + let state = state.expect("Receiver::start returns blocking state"); + self.0.receivers.register(oper, cx, state); self.is_ready() } @@ -719,8 +894,9 @@ impl SelectHandle for Receiver<'_, T> { !self.0.is_empty() || self.0.is_disconnected() } - fn watch(&self, oper: Operation, cx: &Context) -> bool { - self.0.receivers.watch(oper, cx); + fn watch(&self, oper: Operation, cx: &Context, state: Option<&BlockingState<'_>>) -> bool { + let state = state.expect("Receiver::start returns blocking state"); + self.0.receivers.watch(oper, cx, state); self.is_ready() } @@ -729,16 +905,32 @@ impl SelectHandle for Receiver<'_, T> { } } -impl SelectHandle for Sender<'_, T> { +impl<'a, T> Sender<'a, T> { + /// Same as `SelectHandle::start`, but with a more specific lifetime. + pub(crate) fn start_ref(&self) -> Option> { + None + } +} + +impl<'a, T> SelectHandle for Sender<'a, T> { + fn start(&self) -> Option> { + None + } + fn try_select(&self, token: &mut Token) -> bool { - self.0.start_send(token) + self.0.start_send(token) == Status::Ready } fn deadline(&self) -> Option { None } - fn register(&self, _oper: Operation, _cx: &Context) -> bool { + fn register( + &self, + _oper: Operation, + _cx: &Context, + _state: Option<&BlockingState<'_>>, + ) -> bool { self.is_ready() } @@ -752,7 +944,7 @@ impl SelectHandle for Sender<'_, T> { true } - fn watch(&self, _oper: Operation, _cx: &Context) -> bool { + fn watch(&self, _oper: Operation, _cx: &Context, _state: Option<&BlockingState<'_>>) -> bool { self.is_ready() } diff --git a/crossbeam-channel/src/flavors/never.rs b/crossbeam-channel/src/flavors/never.rs index 7a9f830ac..0a985c548 100644 --- a/crossbeam-channel/src/flavors/never.rs +++ b/crossbeam-channel/src/flavors/never.rs @@ -9,6 +9,7 @@ use crate::context::Context; use crate::err::{RecvTimeoutError, TryRecvError}; use crate::select::{Operation, SelectHandle, Token}; use crate::utils; +use crate::waker::BlockingState; /// This flavor doesn't need a token. pub(crate) type NeverToken = (); @@ -72,6 +73,10 @@ impl Channel { } impl SelectHandle for Channel { + fn start(&self) -> Option> { + None + } + #[inline] fn try_select(&self, _token: &mut Token) -> bool { false @@ -83,7 +88,12 @@ impl SelectHandle for Channel { } #[inline] - fn register(&self, _oper: Operation, _cx: &Context) -> bool { + fn register( + &self, + _oper: Operation, + _cx: &Context, + _state: Option<&BlockingState<'_>>, + ) -> bool { self.is_ready() } @@ -101,7 +111,7 @@ impl SelectHandle for Channel { } #[inline] - fn watch(&self, _oper: Operation, _cx: &Context) -> bool { + fn watch(&self, _oper: Operation, _cx: &Context, _state: Option<&BlockingState<'_>>) -> bool { self.is_ready() } diff --git a/crossbeam-channel/src/flavors/tick.rs b/crossbeam-channel/src/flavors/tick.rs index a5b67ed9e..e9feacfee 100644 --- a/crossbeam-channel/src/flavors/tick.rs +++ b/crossbeam-channel/src/flavors/tick.rs @@ -10,6 +10,7 @@ use crossbeam_utils::atomic::AtomicCell; use crate::context::Context; use crate::err::{RecvTimeoutError, TryRecvError}; use crate::select::{Operation, SelectHandle, Token}; +use crate::waker::BlockingState; /// Result of a receive operation. pub(crate) type TickToken = Option; @@ -115,6 +116,10 @@ impl Channel { } impl SelectHandle for Channel { + fn start(&self) -> Option> { + None + } + #[inline] fn try_select(&self, token: &mut Token) -> bool { match self.try_recv() { @@ -136,7 +141,12 @@ impl SelectHandle for Channel { } #[inline] - fn register(&self, _oper: Operation, _cx: &Context) -> bool { + fn register( + &self, + _oper: Operation, + _cx: &Context, + _state: Option<&BlockingState<'_>>, + ) -> bool { self.is_ready() } @@ -154,7 +164,7 @@ impl SelectHandle for Channel { } #[inline] - fn watch(&self, _oper: Operation, _cx: &Context) -> bool { + fn watch(&self, _oper: Operation, _cx: &Context, _state: Option<&BlockingState<'_>>) -> bool { self.is_ready() } diff --git a/crossbeam-channel/src/flavors/zero.rs b/crossbeam-channel/src/flavors/zero.rs index 08d226f87..dfa9f089f 100644 --- a/crossbeam-channel/src/flavors/zero.rs +++ b/crossbeam-channel/src/flavors/zero.rs @@ -15,7 +15,7 @@ use crossbeam_utils::Backoff; use crate::context::Context; use crate::err::{RecvTimeoutError, SendTimeoutError, TryRecvError, TrySendError}; use crate::select::{Operation, SelectHandle, Selected, Token}; -use crate::waker::Waker; +use crate::waker::{BlockingState, Waker}; /// A pointer to a packet. pub(crate) struct ZeroToken(*mut ()); @@ -387,6 +387,10 @@ impl Channel { pub(crate) fn is_full(&self) -> bool { true } + + pub(crate) fn start(&self) -> Option> { + None + } } /// Receiver handle to a channel. @@ -396,6 +400,10 @@ pub(crate) struct Receiver<'a, T>(&'a Channel); pub(crate) struct Sender<'a, T>(&'a Channel); impl SelectHandle for Receiver<'_, T> { + fn start(&self) -> Option> { + None + } + fn try_select(&self, token: &mut Token) -> bool { self.0.start_recv(token) } @@ -404,7 +412,7 @@ impl SelectHandle for Receiver<'_, T> { None } - fn register(&self, oper: Operation, cx: &Context) -> bool { + fn register(&self, oper: Operation, cx: &Context, _state: Option<&BlockingState<'_>>) -> bool { let packet = Box::into_raw(Packet::::empty_on_heap()); let mut inner = self.0.inner.lock().unwrap(); @@ -433,7 +441,7 @@ impl SelectHandle for Receiver<'_, T> { inner.senders.can_select() || inner.is_disconnected } - fn watch(&self, oper: Operation, cx: &Context) -> bool { + fn watch(&self, oper: Operation, cx: &Context, _state: Option<&BlockingState<'_>>) -> bool { let mut inner = self.0.inner.lock().unwrap(); inner.receivers.watch(oper, cx); inner.senders.can_select() || inner.is_disconnected @@ -445,7 +453,11 @@ impl SelectHandle for Receiver<'_, T> { } } -impl SelectHandle for Sender<'_, T> { +impl<'a, T> SelectHandle for Sender<'a, T> { + fn start(&self) -> Option> { + None + } + fn try_select(&self, token: &mut Token) -> bool { self.0.start_send(token) } @@ -454,7 +466,7 @@ impl SelectHandle for Sender<'_, T> { None } - fn register(&self, oper: Operation, cx: &Context) -> bool { + fn register(&self, oper: Operation, cx: &Context, _state: Option<&BlockingState<'_>>) -> bool { let packet = Box::into_raw(Packet::::empty_on_heap()); let mut inner = self.0.inner.lock().unwrap(); @@ -483,7 +495,7 @@ impl SelectHandle for Sender<'_, T> { inner.receivers.can_select() || inner.is_disconnected } - fn watch(&self, oper: Operation, cx: &Context) -> bool { + fn watch(&self, oper: Operation, cx: &Context, _state: Option<&BlockingState<'_>>) -> bool { let mut inner = self.0.inner.lock().unwrap(); inner.senders.watch(oper, cx); inner.receivers.can_select() || inner.is_disconnected diff --git a/crossbeam-channel/src/lib.rs b/crossbeam-channel/src/lib.rs index 35876c160..7faaa732b 100644 --- a/crossbeam-channel/src/lib.rs +++ b/crossbeam-channel/src/lib.rs @@ -362,6 +362,7 @@ mod waker; #[cfg(feature = "std")] pub mod internal { pub use crate::select::{select, select_timeout, try_select, SelectHandle}; + pub use crate::waker::BlockingState; } #[cfg(feature = "std")] diff --git a/crossbeam-channel/src/select.rs b/crossbeam-channel/src/select.rs index ac9e408d3..52df580cc 100644 --- a/crossbeam-channel/src/select.rs +++ b/crossbeam-channel/src/select.rs @@ -15,6 +15,7 @@ use crate::err::{RecvError, SendError}; use crate::err::{SelectTimeoutError, TrySelectError}; use crate::flavors; use crate::utils; +use crate::waker::BlockingState; /// Temporary data that gets initialized during select or a blocking operation, and is consumed by /// `read` or `write`. @@ -98,6 +99,9 @@ impl From for usize { /// appropriate deadline for blocking, etc. // This is a private API (exposed inside crossbeam_channel::internal module) that is used by the select macro. pub trait SelectHandle { + /// Returns a guard that manages the state of the operation. + fn start(&self) -> Option>; + /// Attempts to select an operation and returns `true` on success. fn try_select(&self, token: &mut Token) -> bool; @@ -105,7 +109,7 @@ pub trait SelectHandle { fn deadline(&self) -> Option; /// Registers an operation for execution and returns `true` if it is now ready. - fn register(&self, oper: Operation, cx: &Context) -> bool; + fn register(&self, oper: Operation, cx: &Context, state: Option<&BlockingState<'_>>) -> bool; /// Unregisters an operation for execution. fn unregister(&self, oper: Operation); @@ -117,13 +121,17 @@ pub trait SelectHandle { fn is_ready(&self) -> bool; /// Registers an operation for readiness notification and returns `true` if it is now ready. - fn watch(&self, oper: Operation, cx: &Context) -> bool; + fn watch(&self, oper: Operation, cx: &Context, state: Option<&BlockingState<'_>>) -> bool; /// Unregisters an operation for readiness notification. fn unwatch(&self, oper: Operation); } impl SelectHandle for &T { + fn start(&self) -> Option> { + (**self).start() + } + fn try_select(&self, token: &mut Token) -> bool { (**self).try_select(token) } @@ -132,8 +140,8 @@ impl SelectHandle for &T { (**self).deadline() } - fn register(&self, oper: Operation, cx: &Context) -> bool { - (**self).register(oper, cx) + fn register(&self, oper: Operation, cx: &Context, state: Option<&BlockingState<'_>>) -> bool { + (**self).register(oper, cx, state) } fn unregister(&self, oper: Operation) { @@ -148,8 +156,8 @@ impl SelectHandle for &T { (**self).is_ready() } - fn watch(&self, oper: Operation, cx: &Context) -> bool { - (**self).watch(oper, cx) + fn watch(&self, oper: Operation, cx: &Context, state: Option<&BlockingState<'_>>) -> bool { + (**self).watch(oper, cx, state) } fn unwatch(&self, oper: Operation) { @@ -157,6 +165,46 @@ impl SelectHandle for &T { } } +// A dummy handle used to initialize the handles array by the select! macro. +impl SelectHandle for () { + fn start(&self) -> Option> { + None + } + + fn try_select(&self, _token: &mut Token) -> bool { + false + } + + fn deadline(&self) -> Option { + None + } + + fn register( + &self, + _oper: Operation, + _cx: &Context, + _state: Option<&BlockingState<'_>>, + ) -> bool { + false + } + + fn unregister(&self, _oper: Operation) {} + + fn accept(&self, _token: &mut Token, _cx: &Context) -> bool { + false + } + + fn is_ready(&self) -> bool { + false + } + + fn watch(&self, _oper: Operation, _cx: &Context, _state: Option<&BlockingState<'_>>) -> bool { + false + } + + fn unwatch(&self, _oper: Operation) {} +} + /// Determines when a select operation should time out. #[derive(Clone, Copy, Eq, PartialEq)] enum Timeout { @@ -175,7 +223,12 @@ enum Timeout { /// Successful receive operations will have to be followed up by `channel::read()` and successful /// send operations by `channel::write()`. fn run_select( - handles: &mut [(&dyn SelectHandle, usize, *const u8)], + handles: &mut [( + &dyn SelectHandle, + Option>, + usize, + *const u8, + )], timeout: Timeout, ) -> Option<(Token, usize, *const u8)> { if handles.is_empty() { @@ -202,7 +255,7 @@ fn run_select( let mut token = Token::default(); // Try selecting one of the operations without blocking. - for &(handle, i, ptr) in handles.iter() { + for &(handle, _, i, ptr) in handles.iter() { if handle.try_select(&mut token) { return Some((token, i, ptr)); } @@ -220,11 +273,15 @@ fn run_select( } // Register all operations. - for (handle, i, _) in handles.iter_mut() { + for (handle, state, i, _) in handles.iter_mut() { registered_count += 1; // If registration returns `false`, that means the operation has just become ready. - if handle.register(Operation::hook::<&dyn SelectHandle>(handle), cx) { + if handle.register( + Operation::hook::<&dyn SelectHandle>(handle), + cx, + state.as_ref(), + ) { // Try aborting select. sel = match cx.try_select(Selected::Aborted) { Ok(()) => { @@ -251,7 +308,7 @@ fn run_select( Timeout::Never => None, Timeout::At(when) => Some(when), }; - for &(handle, _, _) in handles.iter() { + for &(handle, _, _, _) in handles.iter() { if let Some(x) = handle.deadline() { deadline = deadline.map(|y| x.min(y)).or(Some(x)); } @@ -262,8 +319,12 @@ fn run_select( } // Unregister all registered operations. - for (handle, _, _) in handles.iter_mut().take(registered_count) { + for (handle, state, _, _) in handles.iter_mut().take(registered_count) { handle.unregister(Operation::hook::<&dyn SelectHandle>(handle)); + + if let Some(state) = state { + state.unpark(); + } } match sel { @@ -271,7 +332,7 @@ fn run_select( Selected::Aborted => { // If an operation became ready during registration, try selecting it. if let Some(index_ready) = index_ready { - for &(handle, i, ptr) in handles.iter() { + for &(handle, _, i, ptr) in handles.iter() { if i == index_ready && handle.try_select(&mut token) { return Some((i, ptr)); } @@ -281,7 +342,7 @@ fn run_select( Selected::Disconnected => {} Selected::Operation(_) => { // Find the selected operation. - for (handle, i, ptr) in handles.iter_mut() { + for (handle, _, i, ptr) in handles.iter_mut() { // Is this the selected operation? if sel == Selected::Operation(Operation::hook::<&dyn SelectHandle>(handle)) { @@ -303,7 +364,7 @@ fn run_select( } // Try selecting one of the operations without blocking. - for &(handle, i, ptr) in handles.iter() { + for &(handle, _, i, ptr) in handles.iter() { if handle.try_select(&mut token) { return Some((token, i, ptr)); } @@ -323,7 +384,12 @@ fn run_select( /// Runs until one of the operations becomes ready, potentially blocking the current thread. fn run_ready( - handles: &mut [(&dyn SelectHandle, usize, *const u8)], + handles: &mut [( + &dyn SelectHandle, + Option>, + usize, + *const u8, + )], timeout: Timeout, ) -> Option { if handles.is_empty() { @@ -348,7 +414,7 @@ fn run_ready( let backoff = Backoff::new(); loop { // Check operations for readiness. - for &(handle, i, _) in handles.iter() { + for &(handle, _, i, _) in handles.iter() { if handle.is_ready() { return Some(i); } @@ -378,12 +444,12 @@ fn run_ready( let mut registered_count = 0; // Begin watching all operations. - for (handle, _, _) in handles.iter_mut() { + for (handle, state, _, _) in handles.iter_mut() { registered_count += 1; let oper = Operation::hook::<&dyn SelectHandle>(handle); // If registration returns `false`, that means the operation has just become ready. - if handle.watch(oper, cx) { + if handle.watch(oper, cx, state.as_ref()) { sel = match cx.try_select(Selected::Operation(oper)) { Ok(()) => Selected::Operation(oper), Err(s) => s, @@ -406,7 +472,7 @@ fn run_ready( Timeout::Never => None, Timeout::At(when) => Some(when), }; - for &(handle, _, _) in handles.iter() { + for &(handle, _, _, _) in handles.iter() { if let Some(x) = handle.deadline() { deadline = deadline.map(|y| x.min(y)).or(Some(x)); } @@ -417,8 +483,11 @@ fn run_ready( } // Unwatch all operations. - for (handle, _, _) in handles.iter_mut().take(registered_count) { + for (handle, state, _, _) in handles.iter_mut().take(registered_count) { handle.unwatch(Operation::hook::<&dyn SelectHandle>(handle)); + if let Some(state) = state { + state.unpark(); + } } match sel { @@ -426,7 +495,7 @@ fn run_ready( Selected::Aborted => {} Selected::Disconnected => {} Selected::Operation(_) => { - for (handle, i, _) in handles.iter_mut() { + for (handle, _, i, _) in handles.iter_mut() { let oper = Operation::hook::<&dyn SelectHandle>(handle); if sel == Selected::Operation(oper) { return Some(*i); @@ -449,7 +518,12 @@ fn run_ready( // This is a private API (exposed inside crossbeam_channel::internal module) that is used by the select macro. #[inline] pub fn try_select<'a>( - handles: &mut [(&'a dyn SelectHandle, usize, *const u8)], + handles: &mut [( + &'a dyn SelectHandle, + Option>, + usize, + *const u8, + )], ) -> Result, TrySelectError> { match run_select(handles, Timeout::Now) { None => Err(TrySelectError), @@ -466,7 +540,12 @@ pub fn try_select<'a>( // This is a private API (exposed inside crossbeam_channel::internal module) that is used by the select macro. #[inline] pub fn select<'a>( - handles: &mut [(&'a dyn SelectHandle, usize, *const u8)], + handles: &mut [( + &'a dyn SelectHandle, + Option>, + usize, + *const u8, + )], ) -> SelectedOperation<'a> { if handles.is_empty() { panic!("no operations have been added to `Select`"); @@ -485,7 +564,12 @@ pub fn select<'a>( // This is a private API (exposed inside crossbeam_channel::internal module) that is used by the select macro. #[inline] pub fn select_timeout<'a>( - handles: &mut [(&'a dyn SelectHandle, usize, *const u8)], + handles: &mut [( + &'a dyn SelectHandle, + Option>, + usize, + *const u8, + )], timeout: Duration, ) -> Result, SelectTimeoutError> { match Instant::now().checked_add(timeout) { @@ -497,7 +581,12 @@ pub fn select_timeout<'a>( /// Blocks until a given deadline, or until one of the operations becomes ready and selects it. #[inline] pub(crate) fn select_deadline<'a>( - handles: &mut [(&'a dyn SelectHandle, usize, *const u8)], + handles: &mut [( + &'a dyn SelectHandle, + Option>, + usize, + *const u8, + )], deadline: Instant, ) -> Result, SelectTimeoutError> { match run_select(handles, Timeout::At(deadline)) { @@ -597,7 +686,12 @@ pub(crate) fn select_deadline<'a>( /// [`ready_timeout`]: Select::ready_timeout pub struct Select<'a> { /// A list of senders and receivers participating in selection. - handles: Vec<(&'a dyn SelectHandle, usize, *const u8)>, + handles: Vec<( + &'a dyn SelectHandle, + Option>, + usize, + *const u8, + )>, /// The next index to assign to an operation. next_index: usize, @@ -643,7 +737,8 @@ impl<'a> Select<'a> { pub fn send(&mut self, s: &'a Sender) -> usize { let i = self.next_index; let ptr = s as *const Sender<_> as *const u8; - self.handles.push((s, i, ptr)); + let state = s.start(); + self.handles.push((s, state, i, ptr)); self.next_index += 1; i } @@ -665,7 +760,8 @@ impl<'a> Select<'a> { pub fn recv(&mut self, r: &'a Receiver) -> usize { let i = self.next_index; let ptr = r as *const Receiver<_> as *const u8; - self.handles.push((r, i, ptr)); + let state = r.start(); + self.handles.push((r, state, i, ptr)); self.next_index += 1; i } @@ -718,7 +814,7 @@ impl<'a> Select<'a> { .handles .iter() .enumerate() - .find(|(_, (_, i, _))| *i == index) + .find(|(_, (_, _, i, _))| *i == index) .expect("no operation with this index") .0; diff --git a/crossbeam-channel/src/select_macro.rs b/crossbeam-channel/src/select_macro.rs index 3b71e1e50..52b038c06 100644 --- a/crossbeam-channel/src/select_macro.rs +++ b/crossbeam-channel/src/select_macro.rs @@ -685,10 +685,16 @@ macro_rules! crossbeam_channel_internal { $default:tt ) => {{ const _LEN: usize = $crate::crossbeam_channel_internal!(@count ($($cases)*)); - let _handle: &dyn $crate::internal::SelectHandle = &$crate::never::<()>(); + + const _STATE: ( + &'static dyn $crate::internal::SelectHandle, + ::core::option::Option<$crate::internal::BlockingState<'static>>, + usize, + *const u8 + ) = (&(), None, 0, ::std::ptr::null()); #[allow(unused_mut)] - let mut _sel = [(_handle, 0, ::std::ptr::null()); _LEN]; + let mut _sel = [_STATE; _LEN]; $crate::crossbeam_channel_internal!( @add @@ -852,7 +858,7 @@ macro_rules! crossbeam_channel_internal { } unbind(_r) }; - $sel[$i] = ($var, $i, $var as *const $crate::Receiver<_> as *const u8); + $sel[$i] = ($var, $crate::internal::SelectHandle::start($var), $i, $var as *const $crate::Receiver<_> as *const u8); $crate::crossbeam_channel_internal!( @add @@ -884,7 +890,7 @@ macro_rules! crossbeam_channel_internal { } unbind(_s) }; - $sel[$i] = ($var, $i, $var as *const $crate::Sender<_> as *const u8); + $sel[$i] = ($var, $crate::internal::SelectHandle::start($var), $i, $var as *const $crate::Sender<_> as *const u8); $crate::crossbeam_channel_internal!( @add diff --git a/crossbeam-channel/src/waker.rs b/crossbeam-channel/src/waker.rs index 7a88c8fdc..8134227aa 100644 --- a/crossbeam-channel/src/waker.rs +++ b/crossbeam-channel/src/waker.rs @@ -1,7 +1,8 @@ //! Waking mechanism for threads blocked on channel operations. +use core::sync::atomic::AtomicU32; use std::ptr; -use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::atomic::Ordering; use std::sync::Mutex; use std::thread::{self, ThreadId}; use std::vec::Vec; @@ -179,8 +180,8 @@ pub(crate) struct SyncWaker { /// The inner `Waker`. inner: Mutex, - /// `true` if the waker is empty. - is_empty: AtomicBool, + /// Atomic state for this waker. + state: WakerState, } impl SyncWaker { @@ -189,58 +190,52 @@ impl SyncWaker { pub(crate) fn new() -> Self { Self { inner: Mutex::new(Waker::new()), - is_empty: AtomicBool::new(true), + state: WakerState::new(), + } + } + + /// Returns a token that can be used to manage the state of a blocking operation. + pub(crate) fn start(&self) -> BlockingState<'_> { + BlockingState { + is_waker: false, + waker: self, } } /// Registers the current thread with an operation. #[inline] - pub(crate) fn register(&self, oper: Operation, cx: &Context) { - let mut inner = self.inner.lock().unwrap(); - inner.register(oper, cx); - self.is_empty.store( - inner.selectors.is_empty() && inner.observers.is_empty(), - Ordering::SeqCst, - ); + pub(crate) fn register(&self, oper: Operation, cx: &Context, state: &BlockingState<'_>) { + self.inner.lock().unwrap().register(oper, cx); + self.state.park(state.is_waker); } /// Unregisters an operation previously registered by the current thread. #[inline] pub(crate) fn unregister(&self, oper: Operation) -> Option { - let mut inner = self.inner.lock().unwrap(); - let entry = inner.unregister(oper); - self.is_empty.store( - inner.selectors.is_empty() && inner.observers.is_empty(), - Ordering::SeqCst, - ); - entry + self.inner.lock().unwrap().unregister(oper) } /// Attempts to find one thread (not the current one), select its operation, and wake it up. #[inline] pub(crate) fn notify(&self) { - if !self.is_empty.load(Ordering::SeqCst) { - let mut inner = self.inner.lock().unwrap(); - if !self.is_empty.load(Ordering::SeqCst) { - inner.try_select(); - inner.notify(); - self.is_empty.store( - inner.selectors.is_empty() && inner.observers.is_empty(), - Ordering::SeqCst, - ); - } + if self.state.try_notify() { + self.notify_one() } } - /// Registers an operation waiting to be ready. + // Finds a thread (not the current one), select its operation, and wake it up. #[inline] - pub(crate) fn watch(&self, oper: Operation, cx: &Context) { + pub(crate) fn notify_one(&self) { let mut inner = self.inner.lock().unwrap(); - inner.watch(oper, cx); - self.is_empty.store( - inner.selectors.is_empty() && inner.observers.is_empty(), - Ordering::SeqCst, - ); + inner.try_select(); + inner.notify(); + } + + /// Registers an operation waiting to be ready. + #[inline] + pub(crate) fn watch(&self, oper: Operation, cx: &Context, state: &BlockingState<'_>) { + self.inner.lock().unwrap().watch(oper, cx); + self.state.park(state.is_waker); } /// Unregisters an operation waiting to be ready. @@ -248,10 +243,6 @@ impl SyncWaker { pub(crate) fn unwatch(&self, oper: Operation) { let mut inner = self.inner.lock().unwrap(); inner.unwatch(oper); - self.is_empty.store( - inner.selectors.is_empty() && inner.observers.is_empty(), - Ordering::SeqCst, - ); } /// Notifies all threads that the channel is disconnected. @@ -259,17 +250,131 @@ impl SyncWaker { pub(crate) fn disconnect(&self) { let mut inner = self.inner.lock().unwrap(); inner.disconnect(); - self.is_empty.store( - inner.selectors.is_empty() && inner.observers.is_empty(), - Ordering::SeqCst, - ); } } impl Drop for SyncWaker { #[inline] fn drop(&mut self) { - debug_assert!(self.is_empty.load(Ordering::SeqCst)); + debug_assert!(!self.state.has_waiters()); + } +} + +/// A guard that manages the state of a blocking operation. +#[derive(Clone)] +#[allow(missing_debug_implementations)] +pub struct BlockingState<'a> { + /// True if this thread is the waker thread, meaning it must + /// try to notify waiters after it completes. + is_waker: bool, + + waker: &'a SyncWaker, +} + +impl BlockingState<'_> { + /// Reset the state after waking up from parking. + #[inline] + pub(crate) fn unpark(&mut self) { + self.is_waker = self.waker.state.unpark(); + } +} + +impl Drop for BlockingState<'_> { + fn drop(&mut self) { + if self.is_waker && self.waker.state.drop_waker() { + self.waker.notify_one(); + } + } +} + +const NOTIFIED: u32 = 0b001; +const WAKER: u32 = 0b010; + +/// The state of a `SyncWaker`. +struct WakerState { + state: AtomicU32, +} + +impl WakerState { + /// Initialize the waker state. + fn new() -> WakerState { + WakerState { + state: AtomicU32::new(0), + } + } + + /// Returns whether or not a waiter needs to be notified. + fn try_notify(&self) -> bool { + // because storing a value in the channel is also sequentially consistent, + // this creates a total order between storing a value and registering a waiter. + let state = self.state.load(Ordering::SeqCst); + + // if a notification is already set, the waker thread will take care + // of further notifications. otherwise we have to notify if there are waiters + if (state >> WAKER) > 0 && (state & NOTIFIED == 0) { + return self + .state + .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |state| { + // set the notification if there are waiters and it is not already set + if (state >> WAKER) > 0 && (state & NOTIFIED == 0) { + Some(state | (WAKER | NOTIFIED)) + } else { + None + } + }) + .is_ok(); + } + + false + } + + /// Get ready for this waker to park. The channel should be checked after calling this + /// method, and before parking. + fn park(&self, waker: bool) { + // increment the waiter count. if we are the waker thread, we also have to remove the + // notification to allow other waiters to be notified after we park + let update = (1_u32 << WAKER).wrapping_sub(u32::from(waker)); + self.state.fetch_add(update, Ordering::SeqCst); + } + + /// Remove this waiter from the waker state after it was unparked. + /// + /// Returns `true` if this thread became the waking thread and must call `drop_waker` + /// after it completes it's operation. + fn unpark(&self) -> bool { + self.state + .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |state| { + // decrement the waiter count and consume the waker token + Some((state - (1 << WAKER)) & !WAKER) + }) + // did we consume the token and become the waker thread? + .map(|state| state & WAKER != 0) + .unwrap() + } + + /// Called by the waking thread after completing it's operation. + /// + /// Returns `true` if a waiter should be notified. + fn drop_waker(&self) -> bool { + self.state + .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |state| { + // if there are waiters, set the waker token and wake someone, transferring the + // waker thread. otherwise unset the notification so new waiters can synchronize + // with new notifications + Some(if (state >> WAKER) > 0 { + state | WAKER + } else { + state.wrapping_sub(NOTIFIED) + }) + }) + // were there waiters? + .map(|state| (state >> WAKER) > 0) + .unwrap() + } + + /// Returns `true` if there are active waiters. + fn has_waiters(&self) -> bool { + (self.state.load(Ordering::Relaxed) >> WAKER) > 0 } } diff --git a/crossbeam-utils/src/backoff.rs b/crossbeam-utils/src/backoff.rs index 9729ce695..9cfce48f7 100644 --- a/crossbeam-utils/src/backoff.rs +++ b/crossbeam-utils/src/backoff.rs @@ -269,7 +269,16 @@ impl Backoff { /// [`AtomicBool`]: std::sync::atomic::AtomicBool #[inline] pub fn is_completed(&self) -> bool { - self.step.get() > YIELD_LIMIT + #[cfg(not(test))] + { + self.step.get() > YIELD_LIMIT + } + + // avoid spinning during tests, which can hide bugs in the waker + #[cfg(test)] + { + true + } } }