diff --git a/ktor-client/ktor-client-core/common/src/io/ktor/client/plugins/internal/ByteChannelReplay.kt b/ktor-client/ktor-client-core/common/src/io/ktor/client/plugins/internal/ByteChannelReplay.kt index e9c9bcfba57..e156338bf44 100644 --- a/ktor-client/ktor-client-core/common/src/io/ktor/client/plugins/internal/ByteChannelReplay.kt +++ b/ktor-client/ktor-client-core/common/src/io/ktor/client/plugins/internal/ByteChannelReplay.kt @@ -23,14 +23,16 @@ internal class ByteChannelReplay(private val origin: ByteReadChannel) { if (copyTask == null) { copyTask = CopyFromSourceTask() if (!content.compareAndSet(null, copyTask)) { - copyTask = content.value + // second thread, read from copy + copyTask = content.value!! } else { + // first thread, read from origin return copyTask.start() } } return GlobalScope.writer { - val body = copyTask!!.awaitImpatiently() + val body = copyTask.awaitImpatiently() channel.writeFully(body) }.channel } @@ -44,12 +46,9 @@ internal class ByteChannelReplay(private val origin: ByteReadChannel) { private inner class CopyFromSourceTask( val savedResponse: CompletableDeferred = CompletableDeferred() ) { - lateinit var writerJob: WriterJob + private val writerJob: WriterJob by lazy { receiveBody() } - fun start(): ByteReadChannel { - writerJob = receiveBody() - return writerJob.channel - } + fun start() = writerJob.channel @OptIn(DelicateCoroutinesApi::class) fun receiveBody(): WriterJob = GlobalScope.writer(Dispatchers.Unconfined) { diff --git a/ktor-client/ktor-client-core/common/test/ByteChannelReplayTest.kt b/ktor-client/ktor-client-core/common/test/ByteChannelReplayTest.kt new file mode 100644 index 00000000000..44a2827c880 --- /dev/null +++ b/ktor-client/ktor-client-core/common/test/ByteChannelReplayTest.kt @@ -0,0 +1,65 @@ +/* + * Copyright 2014-2024 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ +import io.ktor.client.plugins.internal.* +import io.ktor.utils.io.* +import kotlinx.coroutines.joinAll +import kotlinx.coroutines.launch +import kotlinx.coroutines.test.runTest +import kotlinx.coroutines.yield +import kotlin.test.BeforeTest +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertTrue + +internal class ByteChannelReplayTest { + + private val size = 1024 * 1024 + 1 + private val expectedByte = 'A'.code.toByte() + private val expected = ByteArray(size).apply { fill(expectedByte) } + private lateinit var channelReplay: ByteChannelReplay + + @BeforeTest + fun setup() { + channelReplay = ByteChannelReplay(ByteReadChannel(expected)) + } + + @Test + fun readFirst() = runTest { + val first = channelReplay.replay() + assertRead(first) + val second = channelReplay.replay() + assertRead(second) + } + + @Test + fun readSecond() = runTest { + val first = channelReplay.replay() + val second = channelReplay.replay() + assertRead(second) + assertTrue(first.isClosedForRead) + } + + @Test + fun readABunch() = runTest { + val jobs = (0..10).map { + launch { + val readChannel = channelReplay.replay() + yield() + try { + assertRead(readChannel) + } catch (e: Exception) { + assertEquals("Save body abandoned", e.message) + } + } + } + joinAll(*jobs.toTypedArray()) + } + + private suspend fun assertRead(readChannel: ByteReadChannel) { + repeat(size) { i -> + assertEquals(expectedByte, readChannel.readByte(), "Incorrect byte at index $i") + } + assertTrue(readChannel.isClosedForRead || readChannel.exhausted()) + } +}