Skip to content

Commit

Permalink
KTOR-7952 Fix for uninitialized writerJob property
Browse files Browse the repository at this point in the history
  • Loading branch information
bjhham committed Dec 17, 2024
1 parent 37a9a0f commit 38caba7
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import io.ktor.utils.io.*
import io.ktor.utils.io.core.*
import kotlinx.atomicfu.*
import kotlinx.coroutines.*
import kotlinx.coroutines.sync.Mutex
import kotlinx.io.*

internal class ByteChannelReplay(private val origin: ByteReadChannel) {
Expand All @@ -23,14 +24,16 @@ internal class ByteChannelReplay(private val origin: ByteReadChannel) {
if (copyTask == null) {
copyTask = CopyFromSourceTask()
if (!content.compareAndSet(null, copyTask)) {
copyTask = content.value
// second thread, read from copy
copyTask = content.value!!
} else {
// first thread, read from origin
return copyTask.start()
}
}

return GlobalScope.writer {
val body = copyTask!!.awaitImpatiently()
val body = copyTask.awaitImpatiently()
channel.writeFully(body)
}.channel
}
Expand All @@ -44,12 +47,9 @@ internal class ByteChannelReplay(private val origin: ByteReadChannel) {
private inner class CopyFromSourceTask(
val savedResponse: CompletableDeferred<ByteArray> = CompletableDeferred()
) {
lateinit var writerJob: WriterJob
private val writerJob: WriterJob by lazy { receiveBody() }

fun start(): ByteReadChannel {
writerJob = receiveBody()
return writerJob.channel
}
fun start() = writerJob.channel

@OptIn(DelicateCoroutinesApi::class)
fun receiveBody(): WriterJob = GlobalScope.writer(Dispatchers.Unconfined) {
Expand Down
70 changes: 70 additions & 0 deletions ktor-client/ktor-client-core/common/test/ByteChannelReplayTest.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import io.ktor.client.plugins.internal.*
import io.ktor.utils.io.*
import kotlinx.coroutines.joinAll
import kotlinx.coroutines.launch
import kotlinx.coroutines.test.runTest
import kotlinx.coroutines.yield
import kotlin.test.BeforeTest
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertTrue

/*
* Copyright 2014-2024 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license.
*/

internal class ByteChannelReplayTest {

private val size = 1024 * 1024 + 1
private val expectedByte = 'A'.code.toByte()
private val expected = ByteArray(size).apply { fill(expectedByte) }
private lateinit var channelReplay: ByteChannelReplay

@BeforeTest
fun setup() {
channelReplay = ByteChannelReplay(ByteReadChannel(expected))
}

@Test
fun readFirst() = runTest {
val first = channelReplay.replay()
assertRead(first)
val second = channelReplay.replay()
assertRead(second)
assertTrue(second.isClosedForRead)
assertTrue(first.isClosedForRead)
}

@Test
fun readSecond() = runTest {
val first = channelReplay.replay()
val second = channelReplay.replay()
assertRead(second)
assertTrue(second.isClosedForRead)
assertTrue(first.isClosedForRead)
}

@Test
fun readABunch() = runTest {
val jobs = (0..10).map {
launch {
val readChannel = channelReplay.replay()
yield()
try {
assertRead(readChannel)
} catch (e: Exception) {
assertEquals("Save body abandoned", e.message)
}
assertTrue(readChannel.isClosedForRead || readChannel.exhausted())
}
}
joinAll(*jobs.toTypedArray())
}

private suspend fun assertRead(readChannel: ByteReadChannel) {
repeat(size) { i ->
assertEquals(expectedByte, readChannel.readByte(), "Incorrect byte at index $i")
}
}

}

0 comments on commit 38caba7

Please sign in to comment.