diff --git a/core/src/it/scala/com/wixpress/dst/greyhound/core/parallel/ParallelConsumerIT.scala b/core/src/it/scala/com/wixpress/dst/greyhound/core/parallel/ParallelConsumerIT.scala
index 6ef6de44..e9651995 100644
--- a/core/src/it/scala/com/wixpress/dst/greyhound/core/parallel/ParallelConsumerIT.scala
+++ b/core/src/it/scala/com/wixpress/dst/greyhound/core/parallel/ParallelConsumerIT.scala
@@ -104,49 +104,46 @@ class ParallelConsumerIT extends BaseTestWithSharedEnv[Env, TestResources] {
         fastMessages         = allMessages - 1
         drainTimeout         = 5.seconds
-        keyWithSlowHandling  = "slow-key"
-        numProcessedMessges  <- Ref.make[Int](0)
-        fastMessagesLatch    <- CountDownLatch.make(fastMessages)
+        numProcessedMessages <- Ref.make[Int](0)
+        fastMessagesLatch    <- CountDownLatch.make(fastMessages)
         randomKeys           <- ZIO.foreach(1 to fastMessages)(i => randomKey(i.toString)).map(_.toSeq)
         fastRecords          = randomKeys.map { key => recordWithKey(topic, key, partition) }
-        slowRecord           = recordWithKey(topic, keyWithSlowHandling, partition)
+        slowRecord           = recordWithoutKey(topic, partition)
         finishRebalance      <- Promise.make[Nothing, Unit]
         // handler that sleeps only on the slow key
-        handler = RecordHandler { cr: ConsumerRecord[Chunk[Byte], Chunk[Byte]] =>
-                    (cr.key match {
-                      case Some(k) if k == Chunk.fromArray(keyWithSlowHandling.getBytes) =>
-                        // make sure the handler doesn't finish before the rebalance is done, including drain timeout
-                        finishRebalance.await *> ZIO.sleep(drainTimeout + 1.second)
-                      case _ => fastMessagesLatch.countDown
-                    }) *> numProcessedMessges.update(_ + 1)
-                  }
-        _ <-
-          for {
-            consumer <- makeParallelConsumer(handler, kafka, topic, group, cId, drainTimeout = drainTimeout, startPaused = true)
-            _        <- produceRecords(producer, Seq(slowRecord))
-            _        <- produceRecords(producer, fastRecords)
-            _        <- ZIO.sleep(2.seconds)
-            // produce is done synchronously to make sure all records are produced before consumer starts, so all records are polled at once
-            _        <- consumer.resume
-            _        <- fastMessagesLatch.await
-            _        <- ZIO.sleep(3.second) // sleep to ensure commit is done before rebalance
-            // start another consumer to trigger a rebalance before slow handler is done
-            _        <- makeParallelConsumer(
-                          handler,
-                          kafka,
-                          topic,
-                          group,
-                          cId,
-                          drainTimeout = drainTimeout,
-                          onAssigned = _ => finishRebalance.succeed()
-                        )
-          } yield ()
-
-        _ <- eventuallyZ(numProcessedMessges.get, 25.seconds)(_ == allMessages)
+        handler = RecordHandler { cr: ConsumerRecord[Chunk[Byte], Chunk[Byte]] =>
+                    (cr.key match {
+                      case Some(_) =>
+                        fastMessagesLatch.countDown
+                      case None    =>
+                        // make sure the handler doesn't finish before the rebalance is done, including drain timeout
+                        finishRebalance.await *> ZIO.sleep(drainTimeout + 5.second)
+                    }) *> numProcessedMessages.update(_ + 1)
+                  }
+        consumer <- makeParallelConsumer(handler, kafka, topic, group, cId, drainTimeout = drainTimeout, startPaused = true)
+        _        <- produceRecords(producer, Seq(slowRecord))
+        _        <- produceRecords(producer, fastRecords)
+        _        <- ZIO.sleep(2.seconds)
+        // produce is done synchronously to make sure all records are produced before consumer starts, so all records are polled at once
+        _        <- consumer.resume
+        _        <- fastMessagesLatch.await
+        _        <- ZIO.sleep(3.second) // sleep to ensure commit is done before rebalance
+        // start another consumer to trigger a rebalance before slow handler is done
+        _        <- makeParallelConsumer(
+                      handler,
+                      kafka,
+                      topic,
+                      group,
+                      cId,
+                      drainTimeout = drainTimeout,
+                      onAssigned = _ => finishRebalance.succeed()
+                    )
+
+        _ <- eventuallyZ(numProcessedMessages.get, 25.seconds)(_ == allMessages)
       } yield {
         ok
       }
@@ -319,6 +316,9 @@ class ParallelConsumerIT extends BaseTestWithSharedEnv[Env, TestResources] {
   private def recordWithKey(topic: String, key: String, partition: Int) =
     ProducerRecord(topic, "", Some(key), partition = Some(partition))
 
+  private def recordWithoutKey(topic: String, partition: Int) =
+    ProducerRecord(topic, "", None, partition = Some(partition))
+
   private def randomKey(prefix: String) =
     randomId.map(r => s"$prefix-$r")
 }
diff --git a/core/src/main/scala/com/wixpress/dst/greyhound/core/consumer/Dispatcher.scala b/core/src/main/scala/com/wixpress/dst/greyhound/core/consumer/Dispatcher.scala
index d35b6788..44441099 100644
--- a/core/src/main/scala/com/wixpress/dst/greyhound/core/consumer/Dispatcher.scala
+++ b/core/src/main/scala/com/wixpress/dst/greyhound/core/consumer/Dispatcher.scala
@@ -196,15 +196,18 @@ object Dispatcher {
       ZIO
         .foreachParDiscard(workers) {
           case (partition, worker) =>
-            report(StoppingWorker(group, clientId, partition, drainTimeout.toMillis, consumerAttributes)) *>
-              workersShutdownRef.get.flatMap(_.get(partition).fold(ZIO.unit)(promise => promise.onShutdown.shuttingDown)) *>
-              worker.shutdown
-                .catchSomeCause {
-                  case _: Cause[InterruptedException] => ZIO.unit
-                } // happens on revoke - must not fail on it so we have visibility to worker completion
-                .timed
-                .map(_._1)
-                .flatMap(duration => report(WorkerStopped(group, clientId, partition, duration.toMillis, consumerAttributes)))
+            for {
+              _                  <- report(StoppingWorker(group, clientId, partition, drainTimeout.toMillis, consumerAttributes))
+              workersShutdownMap <- workersShutdownRef.get
+              _                  <- workersShutdownMap.get(partition).fold(ZIO.unit)(promise => promise.onShutdown.shuttingDown)
+              duration           <- worker.shutdown
+                                      .catchSomeCause {
+                                        case _: Cause[InterruptedException] => ZIO.unit
+                                      } // happens on revoke - must not fail on it so we have visibility to worker completion
+                                      .timed
+                                      .map(_._1)
+              _                  <- report(WorkerStopped(group, clientId, partition, duration.toMillis, consumerAttributes))
+            } yield ()
         }
         .resurrect
         .ignore
@@ -324,7 +327,7 @@ object Dispatcher {
     override def shutdown: URIO[Any, Unit] =
       for {
         _       <- internalState.update(_.shutdown).commit
-        timeout <- fiber.join.ignore.disconnect.timeout(drainTimeout)
+        timeout <- fiber.join.ignore.interruptible.timeout(drainTimeout)
         _       <- ZIO.when(timeout.isEmpty)(fiber.interruptFork)
       } yield ()
 
@@ -404,19 +407,26 @@ object Dispatcher {
           case DispatcherState.Running =>
            queue.poll.flatMap {
              case Some(record) =>
-                report(TookRecordFromQueue(record, group, clientId, consumerAttributes)) *>
-                  ZIO
-                    .attempt(currentTimeMillis())
-                    .flatMap(t => internalState.updateAndGet(_.startedWith(t)).commit)
-                    .tapBoth(
-                      e => report(FailToUpdateCurrentExecutionStarted(record, group, clientId, consumerAttributes, e)),
-                      t => report(CurrentExecutionStartedEvent(partition, group, clientId, t.currentExecutionStarted))
-                    ) *> handle(record).interruptible.ignore *> isActive(internalState)
-              case None => isActive(internalState).delay(5.millis)
+                for {
+                  _                  <- report(TookRecordFromQueue(record, group, clientId, consumerAttributes))
+                  clock              <- ZIO.clock
+                  executionStartTime <- clock.currentTime(TimeUnit.MILLISECONDS)
+                  _                  <- internalState
+                                          .updateAndGet(_.startedWith(executionStartTime))
+                                          .commit
+                  _                  <- report(CurrentExecutionStartedEvent(partition, group, clientId, Some(executionStartTime)))
+                  _                  <- handle(record).interruptible.ignore
+                  active             <- isActive(internalState)
+                } yield active
+              case None =>
+                isActive(internalState).delay(5.millis)
            }
          case DispatcherState.Paused(resume) =>
-            report(WorkerWaitingForResume(group, clientId, partition, consumerAttributes)) *> resume.await.timeout(30.seconds) *>
-              isActive(internalState)
+            for {
+              _      <- report(WorkerWaitingForResume(group, clientId, partition, consumerAttributes))
+              _      <- resume.await.timeout(30.seconds)
+              active <- isActive(internalState)
+            } yield active
          case DispatcherState.ShuttingDown =>
            ZIO.succeed(false)
        }
diff --git a/core/src/main/scala/com/wixpress/dst/greyhound/core/consumer/batched/BatchEventLoop.scala b/core/src/main/scala/com/wixpress/dst/greyhound/core/consumer/batched/BatchEventLoop.scala
index 4c453446..02c3564c 100644
--- a/core/src/main/scala/com/wixpress/dst/greyhound/core/consumer/batched/BatchEventLoop.scala
+++ b/core/src/main/scala/com/wixpress/dst/greyhound/core/consumer/batched/BatchEventLoop.scala
@@ -99,22 +99,21 @@ private[greyhound] class BatchEventLoopImpl[R](
       )
     )
 
-  private def pollAndHandle()(implicit trace: Trace): URIO[R, Unit] = for {
-    _       <- pauseAndResume().provide(ZLayer.succeed(capturedR))
-    records <-
-      consumer
-        .poll(config.fetchTimeout)
-        .provide(ZLayer.succeed(capturedR))
-        .catchAll(_ => ZIO.succeed(Nil))
-        .flatMap(records => seekRequests.get.map(seeks => records.filterNot(record => seeks.keys.toSet.contains(record.topicPartition))))
-    _       <- handleRecords(records).timed
-                 .tap { case (duration, _) => report(FullBatchHandled(clientId, group, records.toSeq, duration, consumerAttributes)) }
+  private def pollAndHandle()(implicit trace: Trace): URIO[R with GreyhoundMetrics, Unit] = for {
+    _          <- pauseAndResume().ignore
+    allRecords <- consumer
+                    .poll(config.fetchTimeout)
+                    .catchAll(_ => ZIO.succeed(Nil))
+    seeks      <- seekRequests.get.map(_.keySet)
+    records     = allRecords.filterNot(record => seeks.contains(record.topicPartition))
+    _          <- handleRecords(records).timed
+                    .tap { case (duration, _) => report(FullBatchHandled(clientId, group, records.toSeq, duration, consumerAttributes)) }
   } yield ()
 
   private def pauseAndResume()(implicit trace: Trace) = for {
     pr <- elState.shouldPauseAndResume()
-    _  <- ZIO.when(pr.toPause.nonEmpty)((consumer.pause(pr.toPause) *> elState.partitionsPaused(pr.toPause)).ignore)
-    _  <- ZIO.when(pr.toResume.nonEmpty)((consumer.resume(pr.toResume) *> elState.partitionsResumed(pr.toResume)).ignore)
+    _  <- ZIO.when(pr.toPause.nonEmpty)(consumer.pause(pr.toPause) *> elState.partitionsPaused(pr.toPause))
+    _  <- ZIO.when(pr.toResume.nonEmpty)(consumer.resume(pr.toResume) *> elState.partitionsResumed(pr.toResume))
   } yield ()
 
   private def handleRecords(polled: Records)(implicit trace: Trace): ZIO[R, Nothing, Unit] = {
@@ -512,10 +511,10 @@ private[greyhound] class BatchEventLoopState(
     partitionsPaused(pauseResume.toPause) *> partitionsResumed(pauseResume.toResume)
 
   def shouldPauseAndResume[R]()(implicit trace: Trace): URIO[R, PauseResume] = for {
-    pending  <- allPending
+    pending  <- allPending.map(_.keySet)
     paused   <- pausedPartitions
-    toPause   = pending.keySet -- paused
-    toResume  = paused -- pending.keySet
+    toPause   = pending -- paused
+    toResume  = paused -- pending
   } yield PauseResume(toPause, toResume)
 
   def appendPending(records: Consumer.Records)(implicit trace: Trace): UIO[Unit] = {
diff --git a/core/src/test/scala/com/wixpress/dst/greyhound/core/consumer/batched/BatchEventLoopTest.scala b/core/src/test/scala/com/wixpress/dst/greyhound/core/consumer/batched/BatchEventLoopTest.scala
index bb1144cc..fda607f6 100644
--- a/core/src/test/scala/com/wixpress/dst/greyhound/core/consumer/batched/BatchEventLoopTest.scala
+++ b/core/src/test/scala/com/wixpress/dst/greyhound/core/consumer/batched/BatchEventLoopTest.scala
@@ -52,7 +52,7 @@ class BatchEventLoopTest extends JUnitRunnableSpec {
       ZIO.scoped(BatchEventLoop.make(group, ConsumerSubscription.topics(topics: _*), consumer, handler, clientId, retry).flatMap {
         loop =>
           for {
-            _              <- ZIO.succeed(println(s"Should not retry for retry: $retry, cause: $cause"))
+            _              <- ZIO.debug(s"Should not retry for retry: $retry, cause: $cause")
             _              <- givenHandleError(failOnPartition(0, cause))
             _              <- givenRecords(consumerRecords)
             handledRecords <- handled.await(_.nonEmpty)
@@ -79,7 +79,7 @@ class BatchEventLoopTest extends JUnitRunnableSpec {
       ZIO.scoped(BatchEventLoop.make(group, ConsumerSubscription.topics(topics: _*), consumer, handler, clientId, Some(retry)).flatMap {
         loop =>
           for {
-            _        <- ZIO.succeed(println(s"Should retry for cause: $cause"))
+            _        <- ZIO.debug(s"Should retry for cause: $cause")
             _        <- givenHandleError(failOnPartition(0, cause))
             _        <- givenRecords(consumerRecords)
             handled1 <- handled.await(_.nonEmpty)
@@ -153,13 +153,14 @@ class BatchEventLoopTest extends JUnitRunnableSpec {
 
     val consumer = new EmptyConsumer {
       override def poll(timeout: Duration)(implicit trace: Trace): Task[Records] =
-        queue.take
+        queue.take.interruptible
           .timeout(timeout)
           .map(_.getOrElse(Iterable.empty))
-          .tap(r => ZIO.succeed(println(s"poll($timeout): $r")))
+          .tap(r => ZIO.debug(s"poll($timeout): $r"))
+
       override def commit(offsets: Map[TopicPartition, Offset])(implicit trace: Trace): Task[Unit] = {
-        ZIO.succeed(println(s"commit($offsets)")) *> committedOffsetsRef.update(_ ++ offsets)
+        ZIO.debug(s"commit($offsets)") *> committedOffsetsRef.update(_ ++ offsets)
       }
 
       override def commitWithMetadata(offsetsAndMetadata: Map[TopicPartition, OffsetAndMetadata])(
@@ -190,14 +191,16 @@ class BatchEventLoopTest extends JUnitRunnableSpec {
       }
     }
 
-    val handler = new BatchRecordHandler[Any, Throwable, Chunk[Byte], Chunk[Byte]] {
-      override def handle(records: RecordBatch): ZIO[Any, HandleError[Throwable], Any] = {
-        ZIO.succeed(println(s"handle($records)")) *>
-          (handlerErrorsRef.get.flatMap(he => he(records.records).fold(ZIO.unit: IO[HandleError[Throwable], Unit])(ZIO.failCause(_))) *>
-            handled.update(_ :+ records.records))
-            .tapErrorCause(e => ZIO.succeed(println(s"handle failed with $e, records: $records")))
-            .tap(_ => ZIO.succeed(println(s"handled $records")))
-      }
+    val handler = new BatchRecordHandler[Any, Throwable, Chunk[Byte], Chunk[Byte]] {
+      override def handle(records: RecordBatch): ZIO[Any, HandleError[Throwable], Any] = for {
+        _  <- ZIO.debug(s"handle($records)")
+        he <- handlerErrorsRef.get
+        _  <- he(records.records).fold(ZIO.unit: IO[HandleError[Throwable], Unit])(ZIO.failCause(_))
+        _  <- handled
+                .update(_ :+ records.records)
+                .tapErrorCause(e => ZIO.debug(s"handle failed with $e, records: $records"))
+                .tap(_ => ZIO.debug(s"handled $records"))
+      } yield ()
     }
 
     def givenRecords(records: Seq[Consumer.Record]) = queue.offer(records)