diff --git a/std/shared/src/main/scala/cats/effect/std/Mutex.scala b/std/shared/src/main/scala/cats/effect/std/Mutex.scala index dd37867b21..9e761b1016 100644 --- a/std/shared/src/main/scala/cats/effect/std/Mutex.scala +++ b/std/shared/src/main/scala/cats/effect/std/Mutex.scala @@ -63,7 +63,7 @@ object Mutex { /** * Creates a new `Mutex`. */ - def apply[F[_]](implicit F: Concurrent[F]): F[Mutex[F]] = + def apply[F[_]](implicit F: GenConcurrent[F, _]): F[Mutex[F]] = Ref .of[F, ConcurrentImpl.LockQueueCell]( // Initialize the state with an already completed cell. @@ -85,7 +85,7 @@ object Mutex { private final class ConcurrentImpl[F[_]]( state: Ref[F, ConcurrentImpl.LockQueueCell] )( - implicit F: Concurrent[F] + implicit F: GenConcurrent[F, _] ) extends Mutex[F] { // Awakes whoever is waiting for us with the next cell in the queue. private def awakeCell( @@ -162,7 +162,7 @@ object Mutex { // Represents a waiting cell in the queue. private[Mutex] final type WaitingCell[F[_]] = Deferred[F, LockQueueCell] - private[Mutex] def LockQueueCell[F[_]](implicit F: Concurrent[F]): F[WaitingCell[F]] = + private[Mutex] def LockQueueCell[F[_]](implicit F: GenConcurrent[F, _]): F[WaitingCell[F]] = Deferred[F, LockQueueCell] } diff --git a/std/shared/src/main/scala/cats/effect/std/Queue.scala b/std/shared/src/main/scala/cats/effect/std/Queue.scala index e1358e2291..f64e557d82 100644 --- a/std/shared/src/main/scala/cats/effect/std/Queue.scala +++ b/std/shared/src/main/scala/cats/effect/std/Queue.scala @@ -18,7 +18,16 @@ package cats package effect package std -import cats.effect.kernel.{Async, Cont, Deferred, GenConcurrent, MonadCancelThrow, Poll, Ref} +import cats.effect.kernel.{ + Async, + Cont, + Concurrent, + Deferred, + GenConcurrent, + MonadCancelThrow, + Poll, + Ref +} import cats.effect.kernel.syntax.all._ import cats.syntax.all._ @@ -71,7 +80,7 @@ object Queue { * @return * an empty, bounded queue */ - def bounded[F[_], A](capacity: Int)(implicit F: GenConcurrent[F, _]): F[Queue[F, A]] = { + def bounded[F[_], A](capacity: Int)(implicit F: Concurrent[F]): F[Queue[F, A]] = { assertNonNegative(capacity) // async queue can't handle capacity == 1 and allocates eagerly, so cap at 64k @@ -115,8 +124,12 @@ object Queue { * @return * a synchronous queue */ - def synchronous[F[_], A](implicit F: GenConcurrent[F, _]): F[Queue[F, A]] = - F.ref(SyncState.empty[F, A]).map(new Synchronous(_)) + def synchronous[F[_], A](implicit F: Concurrent[F]): F[Queue[F, A]] = + ( + Mutex[F], + Mutex[F], + F.ref[SyncState](SyncState.Empty) + ).mapN(new Synchronous(_, _, _)) /** * Constructs an empty, unbounded queue for `F` data types that are @@ -182,82 +195,127 @@ object Queue { s"$name queue capacity must be positive, was: $capacity") else () - private final class Synchronous[F[_], A](stateR: Ref[F, SyncState[F, A]])( - implicit F: GenConcurrent[F, _]) - extends Queue[F, A] { + private final class Synchronous[F[_], A]( + offers: Mutex[F], + takers: Mutex[F], + state: Ref[F, SyncState] + )( + implicit F: Concurrent[F] + ) extends Queue[F, A] { + override def take: F[A] = + takers.lock.surround { + Deferred[F, SyncState.OffererWaiting[F, A]].flatMap { df => + F.uncancelable { poll => + state + .modify { oldState => + if (oldState eq SyncState.Empty) { + val newState: SyncState.TakerWaiting[F, A] = df + val program = F.onCancel( + poll(df.get), + state.getAndSet(SyncState.Empty).flatMap { s => + if (s eq newState) + F.unit + else + s.asInstanceOf[SyncState.OffererWaiting[F, A]]._2.complete(false).void + } + ) - def offer(a: A): F[Unit] = - F.deferred[Unit] flatMap { latch => - F uncancelable { poll => - val modificationF = stateR modify { - case SyncState(offerers, takers) if takers.nonEmpty => - val (taker, tail) = takers.dequeue - SyncState(offerers, tail) -> taker.complete(a).void - - case SyncState(offerers, takers) => - val cleanupF = stateR update { - case SyncState(offerers, takers) => - SyncState(offerers.filter(_._2 ne latch), takers) + newState -> program + } else { + SyncState.Empty -> F.pure( + oldState.asInstanceOf[SyncState.OffererWaiting[F, A]] + ) + } + } + .flatten + .flatMap { + case (a, offererSignal) => + state.set(SyncState.Empty) >> offererSignal.complete(true).as(a) } - - SyncState(offerers.enqueue((a, latch)), takers) -> - poll(latch.get).onCancel(cleanupF) } - - modificationF.flatten } } - def tryOffer(a: A): F[Boolean] = - stateR.flatModify { - case SyncState(offerers, takers) if takers.nonEmpty => - val (taker, tail) = takers.dequeue - SyncState(offerers, tail) -> taker.complete(a).as(true) - - case st => - st -> F.pure(false) + override def tryTake: F[Option[A]] = + state.access.flatMap { + case (oldState, setter) => + if ((oldState eq SyncState.Empty) || + oldState.isInstanceOf[SyncState.TakerWaiting[F, A]]) + F.pure(None) + else { + val (a, offererSignal) = oldState.asInstanceOf[SyncState.OffererWaiting[F, A]] + setter(SyncState.Empty).flatMap { + case true => + offererSignal.complete(true).as(Some(a)) + + case false => + F.pure(None) + } + } } - val take: F[A] = - F.deferred[A] flatMap { latch => - F uncancelable { poll => - val modificationF = stateR modify { - case SyncState(offerers, takers) if offerers.nonEmpty => - val ((value, offerer), tail) = offerers.dequeue - SyncState(tail, takers) -> offerer.complete(()).as(value) - - case SyncState(offerers, takers) => - val cleanupF = stateR update { - case SyncState(offerers, takers) => - SyncState(offerers, takers.filter(_ ne latch)) - } - - SyncState(offerers, takers.enqueue(latch)) -> poll(latch.get).onCancel(cleanupF) + override def offer(a: A): F[Unit] = + offers.lock.surround { + def loop(): F[Unit] = Deferred[F, Boolean].flatMap { df => + F.uncancelable { poll => + val newState: SyncState.OffererWaiting[F, A] = (a, df) + + val consumed = F.onCancel( + poll(df.get).flatMap { + case true => F.unit + case false => loop() + }, + state.update(s => if (s eq newState) SyncState.Empty else s) + ) + + state.getAndSet(newState).flatMap { oldState => + if (oldState eq SyncState.Empty) + consumed + else + oldState + .asInstanceOf[SyncState.TakerWaiting[F, A]] + .complete(newState) >> consumed + } } - - modificationF.flatten } + + loop() } - val tryTake: F[Option[A]] = - stateR.flatModify { - case SyncState(offerers, takers) if offerers.nonEmpty => - val ((value, offerer), tail) = offerers.dequeue - SyncState(tail, takers) -> offerer.complete(()).as(value.some) + override def tryOffer(a: A): F[Boolean] = + state.access.flatMap { + case (oldState, setter) => + if ((oldState eq SyncState.Empty) || + oldState.isInstanceOf[SyncState.OffererWaiting[F, A]]) + F.pure(false) + else { + setter(SyncState.Empty).flatMap { + case true => + Deferred[F, Boolean].flatMap { df => + // This won't block for long since we complete quickly in this handoff. + F.uncancelable { poll => + oldState.asInstanceOf[SyncState.TakerWaiting[F, A]].complete(a -> df) >> + poll(df.get) + } + } - case st => - st -> none[A].pure[F] + case false => + F.pure(false) + } + } } - val size: F[Int] = F.pure(0) + override final val size: F[Int] = F.pure(0) } - private final case class SyncState[F[_], A]( - offerers: ScalaQueue[(A, Deferred[F, Unit])], - takers: ScalaQueue[Deferred[F, A]]) - + private type SyncState = AnyRef private object SyncState { - def empty[F[_], A]: SyncState[F, A] = SyncState(ScalaQueue(), ScalaQueue()) + type Empty = SyncState + final val Empty: Empty = null + + type OffererWaiting[F[_], A] = (A, Deferred[F, Boolean]) + + type TakerWaiting[F[_], A] = Deferred[F, OffererWaiting[F, A]] } private sealed abstract class AbstractQueue[F[_], A]( diff --git a/tests/shared/src/test/scala/cats/effect/std/QueueSpec.scala b/tests/shared/src/test/scala/cats/effect/std/QueueSpec.scala index d7ada2d16e..a04100122c 100644 --- a/tests/shared/src/test/scala/cats/effect/std/QueueSpec.scala +++ b/tests/shared/src/test/scala/cats/effect/std/QueueSpec.scala @@ -31,7 +31,7 @@ import org.specs2.specification.core.Fragments import scala.collection.immutable.{Queue => ScalaQueue} import scala.concurrent.duration._ -class BoundedQueueSpec extends BaseSpec with QueueTests[Queue] with DetectPlatform { +final class QueueSpec extends BaseSpec with QueueTests[Queue] with DetectPlatform { "BoundedQueue (concurrent)" should { boundedQueueTests(i => if (i == 0) Queue.synchronous else Queue.boundedForConcurrent(i)) @@ -51,7 +51,7 @@ class BoundedQueueSpec extends BaseSpec with QueueTests[Queue] with DetectPlatfo boundedQueueTests(Queue.bounded[IO, Int](_).map(_.mapK(FunctionK.id))) } - "synchronous queue" should { + "SynchronousQueue" should { "respect fifo order" in ticked { implicit ticker => val test = for { q <- Queue.synchronous[IO, Int] @@ -72,6 +72,38 @@ class BoundedQueueSpec extends BaseSpec with QueueTests[Queue] with DetectPlatfo test must completeAs(0.until(5).toList) } + "not lose offer when taker is canceled during exchange" in real { + val test = for { + q <- Queue.synchronous[IO, Unit] + latch <- CountDownLatch[IO](2) + offererDone <- IO.ref(false) + + _ <- (latch.release *> latch.await *> q.offer(())) + .guarantee(offererDone.set(true)) + .start + taker <- (latch.release *> latch.await *> q.take).start + + _ <- latch.await + _ <- taker.cancel + + // we should either have received the value successfully, or we left the value in queue + // what we *don't* want is to remove the value and then lose it due to cancelation + oc <- taker.join + + _ <- + if (oc.isCanceled) { + // we (maybe) hit the race condition + // if we lost the value, q.take will hang + offererDone.get.flatMap(b => IO(b must beFalse)) *> q.take + } else { + // we definitely didn't hit the race condition, because we got the value in taker + IO.unit + } + } yield ok + + test.parReplicateA_(10000).as(ok) + } + "not lose takers when offerer is canceled and there are no other takers" in real { val test = for { q <- Queue.synchronous[IO, Unit]