Skip to content

Commit

Permalink
Add tests with ContentEncoding plugin
Browse files Browse the repository at this point in the history
  • Loading branch information
Stexxe committed Dec 17, 2024
1 parent 6f02606 commit 5796d6b
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ kotlin.sourceSets {
jvmTest {
dependencies {
api(project(":ktor-shared:ktor-serialization:ktor-serialization-jackson"))
api(project(":ktor-client:ktor-client-plugins:ktor-client-encoding"))
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -204,39 +204,42 @@ public val Logging: ClientPlugin<LoggingConfig> = createClientPlugin("Logging",
}
}

on(ResponseHook) { response ->
if (stdFormat) {
val (size, channel) = calcResponseBodySize(response, response.headers)

if (channel != response.rawContent) {
proceedWith(object : HttpResponse() {
override val call: HttpClientCall
get() = response.call
override val status: HttpStatusCode
get() = response.status
override val version: HttpProtocolVersion
get() = response.version
override val requestTime: GMTDate
get() = response.requestTime
override val responseTime: GMTDate
get() = response.responseTime

@InternalAPI
override val rawContent: ByteReadChannel
get() = channel
override val headers: Headers
get() = response.headers
override val coroutineContext: CoroutineContext
get() = response.coroutineContext
})
}

val request = response.request
val duration = response.responseTime.timestamp - response.requestTime.timestamp
logger.log("<-- ${response.status} ${request.url.pathQuery()} ${response.version} (${duration}ms, $size-byte body)")
return@on
on(ResponseAfterEncodingHook) { response ->
if (!stdFormat) return@on

val (size, channel) = calcResponseBodySize(response, response.headers)

if (channel != response.rawContent) {
proceedWith(object : HttpResponse() {
override val call: HttpClientCall
get() = response.call
override val status: HttpStatusCode
get() = response.status
override val version: HttpProtocolVersion
get() = response.version
override val requestTime: GMTDate
get() = response.requestTime
override val responseTime: GMTDate
get() = response.responseTime

@InternalAPI
override val rawContent: ByteReadChannel
get() = channel
override val headers: Headers
get() = response.headers
override val coroutineContext: CoroutineContext
get() = response.coroutineContext
})
}

val request = response.request
val duration = response.responseTime.timestamp - response.requestTime.timestamp
logger.log("<-- ${response.status} ${request.url.pathQuery()} ${response.version} (${duration}ms, $size-byte body)")
}

on(ResponseHook) { response ->
if (stdFormat) return@on

if (level == LogLevel.NONE || response.call.attributes.contains(DisableLogging)) return@on

val callLogger = response.call.attributes[ClientCallLogger]
Expand Down Expand Up @@ -393,7 +396,6 @@ private object ResponseHook : ClientHook<suspend ResponseHook.Context.(response:

class Context(private val context: PipelineContext<HttpResponse, Unit>) {
suspend fun proceed() = context.proceed()
suspend fun proceedWith(response: HttpResponse) = context.proceedWith(response)
}

override fun install(
Expand All @@ -406,6 +408,24 @@ private object ResponseHook : ClientHook<suspend ResponseHook.Context.(response:
}
}

private object ResponseAfterEncodingHook : ClientHook<suspend ResponseAfterEncodingHook.Context.(response: HttpResponse) -> Unit> {

class Context(private val context: PipelineContext<HttpResponse, Unit>) {
suspend fun proceedWith(response: HttpResponse) = context.proceedWith(response)
}

override fun install(
client: HttpClient,
handler: suspend Context.(response: HttpResponse) -> Unit
) {
val afterState = PipelinePhase("AfterState")
client.receivePipeline.insertPhaseAfter(HttpReceivePipeline.State, afterState)
client.receivePipeline.intercept(afterState) {
handler(Context(this), subject)
}
}
}

private object SendHook : ClientHook<suspend SendHook.Context.(response: HttpRequestBuilder) -> Unit> {

class Context(private val context: PipelineContext<Any, HttpRequestBuilder>) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@ package io.ktor.client.plugins.logging

import io.ktor.client.*
import io.ktor.client.engine.mock.*
import io.ktor.client.plugins.compression.ContentEncoding
import io.ktor.client.request.*
import io.ktor.client.statement.bodyAsText
import io.ktor.http.Headers
import io.ktor.http.HttpHeaders
import io.ktor.http.HttpStatusCode
import io.ktor.http.content.OutgoingContent
import io.ktor.util.GZipEncoder
import io.ktor.utils.io.ByteReadChannel
import io.ktor.utils.io.ByteWriteChannel
import io.ktor.utils.io.writeStringUtf8
Expand Down Expand Up @@ -210,6 +212,31 @@ class NewFormatTest {
}
}

@Test
fun basicPostWithGzip() = runTest {
HttpClient(MockEngine) {
install(Logging) {
level = LogLevel.INFO
logger = log
standardFormat = true
}
install(ContentEncoding) { gzip() }

engine {
addHandler {
val channel = GZipEncoder.encode(ByteReadChannel("a".repeat(1024)))
respond(channel, headers = Headers.build { append(HttpHeaders.ContentEncoding, "gzip") })
}
}
}.use { client ->
client.post("/")

log.assertLogEqual("--> POST / (0-byte body)")
.assertLogMatch(Regex("""<-- 200 OK / HTTP/1.1 \(\d+ms, 1024-byte body\)"""))
.assertNoMoreLogs()
}
}

private fun testWithLevel(lvl: LogLevel, handle: MockRequestHandler, test: suspend (HttpClient) -> Unit) = runTest {
HttpClient(MockEngine) {
install(Logging) {
Expand Down

0 comments on commit 5796d6b

Please sign in to comment.