Skip to content

Commit

Permalink
Cherry-pick common ClientLoader implementation from 3.1.0-eap (#4550)
Browse files Browse the repository at this point in the history
* Add nonJvm source set
* Commonize ClientLoader and make it work with multiple engines on js/wasmJs

---------

Co-authored-by: Oleg Yukhnevich <[email protected]>
  • Loading branch information
osipxd and whyoleg authored Dec 17, 2024
1 parent d35aeb9 commit dbcbff9
Show file tree
Hide file tree
Showing 27 changed files with 308 additions and 348 deletions.
14 changes: 10 additions & 4 deletions buildSrc/src/main/kotlin/TargetsConfig.kt
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,18 @@ import java.io.*

private val Project.files: Array<File> get() = project.projectDir.listFiles() ?: emptyArray()
val Project.hasCommon: Boolean get() = files.any { it.name == "common" }
val Project.hasNonJvm: Boolean get() = files.any { it.name == "nonJvm" }
val Project.hasJvmAndPosix: Boolean get() = hasCommon || files.any { it.name == "jvmAndPosix" }
val Project.hasPosix: Boolean get() = hasCommon || hasJvmAndPosix || files.any { it.name == "posix" }
val Project.hasPosix: Boolean get() = hasCommon || hasNonJvm || hasJvmAndPosix || files.any { it.name == "posix" }
val Project.hasDesktop: Boolean get() = hasPosix || files.any { it.name == "desktop" }
val Project.hasNix: Boolean get() = hasPosix || files.any { it.name == "nix" }
val Project.hasLinux: Boolean get() = hasNix || files.any { it.name == "linux" }
val Project.hasDarwin: Boolean get() = hasNix || files.any { it.name == "darwin" }
val Project.hasAndroidNative: Boolean get() = hasPosix || files.any { it.name == "androidNative" }
val Project.hasWindows: Boolean get() = hasPosix || files.any { it.name == "windows" }
val Project.hasJsAndWasmShared: Boolean get() = files.any { it.name == "jsAndWasmShared" }
val Project.hasJs: Boolean get() = hasCommon || files.any { it.name == "js" } || hasJsAndWasmShared
val Project.hasWasmJs: Boolean get() = hasCommon || files.any { it.name == "wasmJs" } || hasJsAndWasmShared
val Project.hasJsAndWasmShared: Boolean get() = hasCommon || hasNonJvm || files.any { it.name == "jsAndWasmShared" }
val Project.hasJs: Boolean get() = hasJsAndWasmShared || files.any { it.name == "js" }
val Project.hasWasmJs: Boolean get() = hasJsAndWasmShared || files.any { it.name == "wasmJs" }
val Project.hasJvm: Boolean get() = hasCommon || hasJvmAndPosix || files.any { it.name == "jvm" }

val Project.hasExplicitNative: Boolean
Expand Down Expand Up @@ -114,6 +115,11 @@ private val hierarchyTemplate = KotlinHierarchyTemplate {
group("windows")
group("macos")
}

group("nonJvm") {
group("posix")
group("jsAndWasmShared")
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import kotlin.test.assertFailsWith
class AuthTest : ClientLoader() {

@Test
fun testDigestAuthLegacy() = clientTests(listOf("Js", "native")) {
fun testDigestAuthLegacy() = clientTests(listOf("Js", "native:*")) {
config {
install(Auth) {
digest {
Expand All @@ -43,7 +43,7 @@ class AuthTest : ClientLoader() {
}

@Test
fun testDigestAuth() = clientTests(listOf("Js", "native")) {
fun testDigestAuth() = clientTests(listOf("Js", "native:*")) {
config {
install(Auth) {
digest {
Expand All @@ -60,7 +60,7 @@ class AuthTest : ClientLoader() {
}

@Test
fun testDigestAuthPerRealm() = clientTests(listOf("Js", "native")) {
fun testDigestAuthPerRealm() = clientTests(listOf("Js", "native:*")) {
config {
install(Auth) {
digest {
Expand All @@ -84,7 +84,7 @@ class AuthTest : ClientLoader() {
}

@Test
fun testDigestAuthSHA256() = clientTests(listOf("Js", "native")) {
fun testDigestAuthSHA256() = clientTests(listOf("Js", "native:*")) {
config {
install(Auth) {
digest {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,20 @@
package io.ktor.client.tests.utils

import io.ktor.client.engine.*
import kotlinx.coroutines.test.*
import kotlinx.coroutines.test.TestResult
import kotlinx.coroutines.test.runTest
import kotlin.time.Duration
import kotlin.time.Duration.Companion.minutes

internal expect val enginesToTest: Iterable<HttpClientEngineFactory<HttpClientEngineConfig>>
internal expect val platformName: String
internal expect fun platformDumpCoroutines()
internal expect fun platformWaitForAllCoroutines()

/**
* Helper interface to test client.
*/
expect abstract class ClientLoader(timeoutSeconds: Int = 60) {
abstract class ClientLoader(private val timeout: Duration = 1.minutes) {
/**
* Perform test against all clients from dependencies.
*/
Expand All @@ -19,10 +27,110 @@ expect abstract class ClientLoader(timeoutSeconds: Int = 60) {
onlyWithEngine: String? = null,
retries: Int = 1,
block: suspend TestClientBuilder<HttpClientEngineConfig>.() -> Unit
): TestResult
): TestResult = runTest(timeout = timeout) {
val skipPatterns = skipEngines.map(SkipEnginePattern::parse)

val failures: List<TestFailure> = enginesToTest.mapNotNull { engineFactory ->
val engineName = engineFactory.engineName

if (shouldRun(engineName, skipPatterns, onlyWithEngine)) {
try {
println("Run test with engine $engineName")
// run test here
performTestWithEngine(engineFactory, this@ClientLoader, retries, block)
null // engine test passed
} catch (cause: Throwable) {
// engine test failed, save failure to report after run for every engine.
TestFailure(engineName, cause)
}
} else {
println("Skipping test with engine $engineName")
null // engine skipped
}
}

if (failures.isNotEmpty()) {
val message = buildString {
appendLine("Test failed for engines: ${failures.map { it.engineName }}")
failures.forEach {
appendLine("Test failed for engine '$platformName:${it.engineName}' with:")
appendLine(it.cause.stackTraceToString().prependIndent(" "))
}
}
throw AssertionError(message)
}
}

private fun shouldRun(
engineName: String,
skipEnginePatterns: List<SkipEnginePattern>,
onlyWithEngine: String?
): Boolean {
val lowercaseEngineName = engineName.lowercase()
if (onlyWithEngine != null && onlyWithEngine.lowercase() != lowercaseEngineName) return false

skipEnginePatterns.forEach {
if (it.matches(lowercaseEngineName)) return false
}

return true
}

/**
* Print coroutines in debug mode.
*/
fun dumpCoroutines()
fun dumpCoroutines(): Unit = platformDumpCoroutines()

// Issues to fix before unlocking:
// 1. Pinger & Ponger in ws
// 2. Nonce generator
// @After
fun waitForAllCoroutines(): Unit = platformWaitForAllCoroutines()
}

internal val HttpClientEngineFactory<*>.engineName: String
get() = this::class.simpleName!!

private data class SkipEnginePattern(
val skippedPlatform: String?, // null means * or empty
val skippedEngine: String?, // null means * or empty
) {
fun matches(engineName: String): Boolean {
var result = true
if (skippedEngine != null) {
result = result && engineName == skippedEngine
}
if (result && skippedPlatform != null) {
result = result && platformName.startsWith(skippedPlatform)
}
return result
}

companion object {
fun parse(pattern: String): SkipEnginePattern {
val parts = pattern.lowercase().split(":").map { it.takeIf { it != "*" } }
val platform: String?
val engine: String?
when (parts.size) {
1 -> {
platform = null
engine = parts[0]
}

2 -> {
platform = parts[0]
engine = parts[1]
}

else -> error("Skip engine pattern should consist of two parts: PLATFORM:ENGINE or ENGINE")
}

if (platform == null && engine == null) {
error("Skip engine pattern should consist of two parts: PLATFORM:ENGINE or ENGINE")
}
return SkipEnginePattern(platform, engine)
}
}
}

private class TestFailure(val engineName: String, val cause: Throwable)
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,13 @@
* Copyright 2014-2024 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license.
*/

@file:Suppress("NO_EXPLICIT_RETURN_TYPE_IN_API_MODE_WARNING", "KDocMissingDocumentation")

package io.ktor.client.tests.utils

import io.ktor.client.*
import io.ktor.client.engine.*
import io.ktor.utils.io.core.*
import kotlinx.coroutines.*
import kotlinx.coroutines.test.*
import kotlinx.coroutines.test.runTest
import kotlin.time.Duration.Companion.milliseconds

/**
Expand Down Expand Up @@ -65,14 +63,23 @@ private fun testWithClient(
/**
* Perform test with selected client engine [factory].
*/
@OptIn(DelicateCoroutinesApi::class)
fun <T : HttpClientEngineConfig> testWithEngine(
factory: HttpClientEngineFactory<T>,
loader: ClientLoader? = null,
timeoutMillis: Long = 60L * 1000L,
retries: Int = 1,
block: suspend TestClientBuilder<T>.() -> Unit
) = runTest(timeout = timeoutMillis.milliseconds) {
performTestWithEngine(factory, loader, retries, block)
}

@OptIn(DelicateCoroutinesApi::class)
suspend fun <T : HttpClientEngineConfig> performTestWithEngine(
factory: HttpClientEngineFactory<T>,
loader: ClientLoader? = null,
retries: Int = 1,
block: suspend TestClientBuilder<T>.() -> Unit
) {
val builder = TestClientBuilder<T>().apply { block() }

if (builder.dumpAfterDelay > 0 && loader != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ private val TEST_ARRAY = ByteArray(8 * 1025) { 1 }
private val TEST_NAME = "123".repeat(5000)

@OptIn(DelicateCoroutinesApi::class)
class BodyProgressTest : ClientLoader(timeoutSeconds = 60) {
class BodyProgressTest : ClientLoader() {

@Serializable
data class User(val login: String, val id: Long)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@ import io.ktor.http.content.*
import io.ktor.serialization.kotlinx.json.*
import io.ktor.utils.io.*
import io.ktor.utils.io.core.*
import kotlinx.coroutines.*
import kotlinx.io.*
import kotlinx.coroutines.cancel
import kotlinx.coroutines.withTimeoutOrNull
import kotlinx.io.readByteArray
import kotlin.test.*
import kotlin.time.Duration.Companion.minutes

Expand All @@ -42,7 +43,7 @@ val testArrays = testSize.map {
makeArray(it)
}

class ContentTest : ClientLoader(5 * 60) {
class ContentTest : ClientLoader(timeout = 5.minutes) {

@Test
fun testGetFormData() = clientTests {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ class EventsTest : ClientLoader() {
}

@Test
fun testRedirectEvent() = clientTests(listOf("js")) {
fun testRedirectEvent() = clientTests(listOf("Js")) {
test { client ->
counter.value = 0
client.monitor.subscribe(HttpResponseRedirectEvent) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,8 @@ import io.ktor.client.plugins.*
import io.ktor.client.request.*
import io.ktor.client.tests.utils.*
import io.ktor.http.*
import io.ktor.utils.io.errors.*
import kotlinx.coroutines.*
import kotlinx.io.IOException
import kotlin.math.*
import kotlin.test.*

class HttpRequestRetryTest {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ class HttpTimeoutTest : ClientLoader() {
// Fix https://youtrack.jetbrains.com/issue/KTOR-7885
@Ignore
@Test
fun testRedirect() = clientTests(listOf("js"), retries = 5) {
fun testRedirect() = clientTests(listOf("Js"), retries = 5) {
config {
install(HttpTimeout) { requestTimeoutMillis = 10000 }
}
Expand All @@ -344,7 +344,7 @@ class HttpTimeoutTest : ClientLoader() {

// Js can't configure test timeout in browser
@Test
fun testRedirectPerRequestAttributes() = clientTests(listOf("js")) {
fun testRedirectPerRequestAttributes() = clientTests(listOf("Js")) {
config {
install(HttpTimeout)
}
Expand Down Expand Up @@ -429,7 +429,7 @@ class HttpTimeoutTest : ClientLoader() {
}

@Test
fun testConnectionRefusedException() = clientTests(listOf("Js", "native:*", "win:*")) {
fun testConnectionRefusedException() = clientTests(listOf("Js", "native:*", "jvm/win:*")) {
config {
install(HttpTimeout) { connectTimeoutMillis = 1000 }
}
Expand All @@ -445,7 +445,7 @@ class HttpTimeoutTest : ClientLoader() {
}

@Test
fun testSocketTimeoutRead() = clientTests(listOf("Js", "native:CIO", "Java")) {
fun testSocketTimeoutRead() = clientTests(listOf("Js", "native:CIO", "Curl", "Java")) {
config {
install(HttpTimeout) { socketTimeoutMillis = 1000 }
}
Expand All @@ -460,7 +460,9 @@ class HttpTimeoutTest : ClientLoader() {
}

@Test
fun testSocketTimeoutReadPerRequestAttributes() = clientTests(listOf("Js", "native:CIO", "Java", "Apache5")) {
fun testSocketTimeoutReadPerRequestAttributes() = clientTests(
listOf("Js", "native:CIO", "Curl", "Java", "Apache5")
) {
config {
install(HttpTimeout)
}
Expand All @@ -477,7 +479,7 @@ class HttpTimeoutTest : ClientLoader() {
}

@Test
fun testSocketTimeoutWriteFailOnWrite() = clientTests(listOf("Js", "Android", "native:CIO", "Java")) {
fun testSocketTimeoutWriteFailOnWrite() = clientTests(listOf("Js", "Android", "native:CIO", "Curl", "Java")) {
config {
install(HttpTimeout) { socketTimeoutMillis = 500 }
}
Expand All @@ -491,7 +493,7 @@ class HttpTimeoutTest : ClientLoader() {

@Test
fun testSocketTimeoutWriteFailOnWritePerRequestAttributes() = clientTests(
listOf("Js", "Android", "Apache5", "native:CIO", "Java")
listOf("Js", "Android", "Apache5", "native:CIO", "Curl", "Java")
) {
config {
install(HttpTimeout)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class JsonTest : ClientLoader() {
data class Result<T>(val message: String, val data: T)

@Test
fun testUserGenerics() = clientTests(listOf("js")) {
fun testUserGenerics() = clientTests(listOf("Js")) {
config {
install(ContentNegotiation) { json() }
}
Expand Down
Loading

0 comments on commit dbcbff9

Please sign in to comment.